// Copyright (c) Facebook, Inc. and its affiliates. // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. #include "common.cuh" #include "kernels.cuh" #include #include #include #include #include #include #include #include #include #include #if CCCL_VERSION >= 2008002 #include #define CUB_REDUCTIONOP_MAX \ cuda::maximum<> {} #else #define CUB_REDUCTIONOP_MAX cub::Max() #endif #define HLF_MAX 65504 #define TH 1024 #define NUM 4 #define NUM_BLOCK 4096 __device__ static float fp4_dequantization_lut[8] = { 0.0f, // 0b000 0.005208333333f, // 0b001 0.66666667f, // 0b010 1.0f, // 0b011 0.33333333f, // 0b100 0.5f, // 0b101 0.16666667f, // 0b110 0.25f // 0b111 }; __device__ static float nf4_dequantization_lut[16] = { -1.0f, // 0b0000 -0.6961928009986877f, // 0b0001 -0.5250730514526367f, // 0b0010 -0.39491748809814453f, // 0b0011 -0.28444138169288635f, // 0b0100 -0.18477343022823334f, // 0b0101 -0.09105003625154495f, // 0b0110 0.0f, // 0b0111 0.07958029955625534f, // 0b1000 0.16093020141124725f, // 0b1001 0.24611230194568634f, // 0b1010 0.33791524171829224f, // 0b1011 0.44070982933044434f, // 0b1100 0.5626170039176941f, // 0b1101 0.7229568362236023f, // 0b1110 1.0f // 0b1111 }; // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda __device__ float atomicMax(float* address, float val) { int* address_as_i = reinterpret_cast(address); int old = *address_as_i, assumed; do { assumed = old; old = atomicCAS(reinterpret_cast(address), assumed, __float_as_int(fmaxf(val, __int_as_float(assumed)))); } while (assumed != old); return __int_as_float(old); } __device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) { float sign = 1.0f - 2 * ((val & 0b1000) >> 3); return fp4_dequantization_lut[val & 0b111] * sign; } __device__ unsigned char dQuantizeFP4(float x) { // FP4 with bias of 3 // first bit is a sign // subnormals // 0b000 = 0 // 0b001 = 0.0625 // 0b110 = 2 // 0b111 = 3 // 0b100 = 4 // 0b101 = 6 // 0b010 = 8 // 0b011 = 12 // we do a binary search // the pivots are divided by 12 (the FP4 absmax) // since we assume input data is in [-1.0, 1.0] // !be careful here, its easy to make a mistake // that is difficult to notice if you add an extra // zero somewhere! int sign = x < 0 ? 0b1000 : 0b0000; x = fabsf(x); if (x > 0.29166667f) if (x > 0.583333f) if (x > 0.8333333f) return 0b0011 + sign; else return 0b0010 + sign; else if (x > 0.4166667f) return 0b101 + sign; else return 0b100 + sign; else if (x > 0.0859375f) if (x > 0.20833333f) return 0b0111 + sign; else return 0b0110 + sign; else if (x > 0.00260417f) return 0b0001 + sign; else return 0b0000 + sign; } __device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; } __device__ unsigned char dQuantizeNF4(float x) { // the values for this tree was generated by test_normal_map_tree // in the file tests/test_functional.py if (x > 0.03979014977812767f) if (x > 0.3893125355243683f) // 1 if (x > 0.6427869200706482f) // 11 if (x > 0.8614784181118011f) // 111 return 0b1111; else return 0b1110; else if (x > 0.5016634166240692f) // 110 return 0b1101; else return 0b1100; else if (x > 0.2035212516784668f) // 10 if (x > 0.2920137718319893f) // 101 return 0b1011; else return 0b1010; else if (x > 0.1202552504837513f) // 100 return 0b1001; else return 0b1000; else if (x > -0.33967943489551544f) // 0 if (x > -0.13791173323988914f) // 01 if (x > -0.045525018125772476f) // 011 return 0b0111; else return 0b0110; else if (x > -0.23460740596055984f) // 010 return 0b0101; else return 0b0100; else if (x > -0.6106329262256622f) // 00 if (x > -0.4599952697753906f) // 001 return 0b0011; else return 0b0010; else if (x > -0.8480964004993439f) // 000 return 0b0001; else return 0b0000; } // sign function for lion // taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA template __device__ int sgn(T val) { return (T(0) < val) - (val < T(0)); } template __device__ unsigned char dQuantize(float* smem_code, const float rand, float x) { int pivot = 127; int upper_pivot = 255; int lower_pivot = 0; float lower = -1.0f; float upper = 1.0f; float val = smem_code[pivot]; // i>>=1 = {32, 16, 8, 4, 2, 1} for (int i = 64; i > 0; i >>= 1) { if (x > val) { lower_pivot = pivot; lower = val; pivot += i; } else { upper_pivot = pivot; upper = val; pivot -= i; } val = smem_code[pivot]; } if (upper_pivot == 255) upper = smem_code[upper_pivot]; if (lower_pivot == 0) lower = smem_code[lower_pivot]; if (!STOCHASTIC) { if (x > val) { float midpoint = (upper + val) * 0.5f; if (x > midpoint) { return upper_pivot; } else return pivot; } else { float midpoint = (lower + val) * 0.5f; if (x < midpoint) return lower_pivot; else return pivot; } } else { if (x > val) { float dist_to_upper = fabsf(upper - x); float dist_full = upper - val; if (rand >= dist_to_upper / dist_full) return upper_pivot; else return pivot; } else { float dist_to_lower = fabsf(lower - x); float dist_full = val - lower; if (rand >= dist_to_lower / dist_full) return lower_pivot; else return pivot; } } } template __device__ __forceinline__ unsigned char quantize_2D(float* __restrict__ quadrants, float* __restrict__ const smem_code, float x) { int pivot = 127; int upper_pivot = 255; int lower_pivot = 0; float lower = SIGNED ? -1.0f : 0.0f; float upper = 1.0f; float midpoint; float val = quadrants[1]; int local_pivot = 1; int offset = 1; // i>>=1 = {32, 16, 8, 4, 2, 1} for (int i = 64; i > 0; i >>= 1) { if (x > val) { lower_pivot = pivot; lower = val; pivot += i; // val = i == 64 ? quadrants[2] : smem_code[pivot]; local_pivot += offset; } else { upper_pivot = pivot; upper = val; pivot -= i; // val = i == 64 ? quadrants[0] : smem_code[pivot]; local_pivot -= offset; } val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot]; offset -= 1; } if (x > val) { midpoint = (upper + val) * 0.5f; if (x > midpoint) return upper_pivot; else return pivot; } else { midpoint = (lower + val) * 0.5f; if (x < midpoint) return lower_pivot; else return pivot; } } __launch_bounds__(TH, 4) __global__ void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n) { const int n_full = (NUM_BLOCK * (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); int valid_items = (blockIdx.x + 1 == gridDim.x) ? n - (blockIdx.x * NUM_BLOCK) : NUM_BLOCK; const int base_idx = (blockIdx.x * NUM_BLOCK); float vals[NUM]; unsigned char qvals[NUM]; // const int lane_id = threadIdx.x % 2; typedef cub::BlockLoad LoadFloat; typedef cub::BlockStore StoreChar; __shared__ typename LoadFloat::TempStorage loadf; __shared__ typename StoreChar::TempStorage storec; __shared__ float smem_code[256]; //__shared__ float smem_code[2][257]; if (threadIdx.x < 256) { smem_code[threadIdx.x] = code[threadIdx.x]; // smem_code[0][threadIdx.x] = code[threadIdx.x]; // smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x]; } for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_BLOCK) { // number of values already processed in blocks + // number of values already processed in this block + // rand_offset % mod value valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; __syncthreads(); LoadFloat(loadf).Load(&(A[i]), vals, valid_items); #pragma unroll 4 for (int j = 0; j < NUM; j++) qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]); __syncthreads(); StoreChar(storec).Store(&(out[i]), qvals, valid_items); } } template //__launch_bounds__(TH, 4) __global__ void kQuantizeBlockwise( float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, const int rand_offset, const int n ) { const int n_full = gridDim.x * BLOCK_SIZE; int valid_items = 0; const int base_idx = (blockIdx.x * BLOCK_SIZE); T vals[NUM_PER_TH]; float rand_vals[NUM_PER_TH]; unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH]; // float local_abs_max = -FLT_MAX; float local_abs_max = 0.0f; int local_rand_idx = 0; typedef cub::BlockLoad LoadT; typedef cub::BlockStore< unsigned char, BLOCK_SIZE / NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; typedef cub::BlockReduce BlockReduce; typedef cub::BlockLoad LoadFloat; __shared__ typename LoadT::TempStorage loadt; __shared__ typename LoadFloat::TempStorage loadf; __shared__ typename StoreChar::TempStorage storec; __shared__ typename BlockReduce::TempStorage reduce; __shared__ float smem_code[256]; __shared__ float smem_absmax_value[1]; if (DATA_TYPE == General8bit) for (int i = threadIdx.x; i < 256; i += blockDim.x) smem_code[i] = code[i]; for (int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; local_abs_max = -FLT_MAX; __syncthreads(); LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); // 1. compute local max // 2. broadcast local max // 3. normalize inputs and quantize #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, CUB_REDUCTIONOP_MAX, valid_items); if (threadIdx.x == 0) { smem_absmax_value[0] = 1.0f / local_abs_max; absmax[i / BLOCK_SIZE] = local_abs_max; } __syncthreads(); local_abs_max = smem_absmax_value[0]; if (STOCHASTIC) { local_rand_idx = ((blockIdx.x * NUM_BLOCK) + (threadIdx.x * NUM) + rand_offset) % (1024 - 4); LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); } switch (DATA_TYPE) { case General8bit: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) { if (!STOCHASTIC) qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j]) * local_abs_max); else qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j]) * local_abs_max); } break; case FP4: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH / 2; j++) { qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4; qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); } break; case NF4: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH / 2; j++) { qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); } break; } __syncthreads(); StoreChar(storec).Store( &(out[(DATA_TYPE > 0) ? i / 2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items + 1) / 2 : valid_items ); } } template __global__ void kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n) { const int n_load = (gridDim.x * TILE_SIZE); int valid_items_load = 0; int valid_items_store = 0; const int base_idx = (blockIdx.x * TILE_SIZE); T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)]; unsigned char qvals[NUM_PER_TH]; float local_abs_max = -FLT_MAX; typedef cub::BlockLoad LoadChar; typedef cub::BlockStore 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; __shared__ typename LoadChar::TempStorage loadchar; __shared__ typename StoreT::TempStorage storet; for (int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) { if (DATA_TYPE > 0) { valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i); valid_items_store = min(TILE_SIZE * 2, n - i * 2); } else { valid_items_load = min(TILE_SIZE, n - i); valid_items_store = valid_items_load; } // Since blocksize will always be a power-of-2, we avoid more expensive // division by the blocksize and instead use a shift operation. // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) >> (31 - __clz(blocksize))]); __syncthreads(); LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); switch (DATA_TYPE) { case General8bit: // load code through read-only cache via __ldg #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) vals[j] = __ldg(&code[qvals[j]]) * local_abs_max; break; case FP4: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) { vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max; vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max; } break; case NF4: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) { vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max; vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max; } break; } __syncthreads(); StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i * 2 : i]), vals, valid_items_store); } } __global__ void kDequantize(float* code, unsigned char* A, float* out, const int n) { const unsigned int numThreads = blockDim.x * gridDim.x; const int idx = (blockIdx.x * blockDim.x) + threadIdx.x; __shared__ float smem_code[256]; if (threadIdx.x < 256) { smem_code[threadIdx.x] = code[threadIdx.x]; } __syncthreads(); for (int i = idx; i < n; i += numThreads) { out[i] = smem_code[A[i]]; } } template __launch_bounds__(BLOCK_SIZE / NUM_VALS, 1) __global__ void kPreconditionOptimizer32bit2State( T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n ) { const int n_full = (BLOCK_SIZE * (n / BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); int valid_items = 0; T g_vals[NUM_VALS]; float s1_vals[NUM_VALS]; float s2_vals[NUM_VALS]; const float correction1 = 1.0f / (1.0f - powf(beta1, step)); const float correction2 = 1.0f / (1.0f - powf(beta2, step)); typedef cub::BlockLoad Load; typedef cub::BlockLoad LoadFloat; typedef cub::BlockReduce BlockReduce; __shared__ union { typename Load::TempStorage load; typename LoadFloat::TempStorage loadf; typename BlockReduce::TempStorage reduce; } temp_storage; for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; __syncthreads(); Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); #pragma unroll NUM_VALS for (unsigned int j = 0; j < NUM_VALS; j++) g_vals[j] = gnorm_scale * ((float)g_vals[j]); #pragma unroll NUM_VALS for (unsigned int j = 0; j < NUM_VALS; j++) { switch (OPTIMIZER) { case ADAM: s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j])); s2_vals[j] = s2_vals[j] * beta2 + ((1.0f - beta2) * (((float)g_vals[j]) * ((float)g_vals[j]))); s1_vals[j] *= correction1; s2_vals[j] *= correction2; s1_vals[j] = s1_vals[j] / (sqrtf(s2_vals[j]) + eps); // update s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) break; } } #pragma unroll NUM_VALS - 1 for (unsigned int j = 1; j < NUM_VALS; j++) s1_vals[0] += s1_vals[j]; __syncthreads(); s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]); if (threadIdx.x == 0) atomicAdd(&unorm[0], s1_vals[0]); __syncwarp(); } } #define NUM_PER_THREAD 4 template __launch_bounds__(TH, 1) __global__ void kOptimizer32bit2State( T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n ) { const int n_full = ((TH * NUM_PER_THREAD) * (n / (TH * NUM_PER_THREAD))) + (n % (TH * NUM_PER_THREAD) == 0 ? 0 : (TH * NUM_PER_THREAD)); const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); int valid_items = 0; float update_scale = 0.0f; T g_vals[NUM_PER_THREAD]; T p_vals[NUM_PER_THREAD]; float s1_vals[NUM_PER_THREAD]; float s2_vals[NUM_PER_THREAD]; // AdEMAMix has an additional state buffer, which we packed // into state1. We need thread-local storage here for these. // TODO: Mark with [[maybe_unused]] after upgrade to min compiler. float s3_vals[NUM_PER_THREAD]; const float correction1 = 1.0f - powf(beta1, step); const float correction2 = sqrtf(1.0f - powf(beta2, step)); const float step_size = -lr * correction2 / correction1; if (max_unorm > 0.0f) { update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; if (update_scale > max_unorm * param_norm) { update_scale = (max_unorm * param_norm) / update_scale; } else { update_scale = 1.0f; } } else { update_scale = 1.0f; } typedef cub::BlockLoad Load; typedef cub::BlockStore Store; typedef cub::BlockLoad LoadFloat; typedef cub::BlockStore StoreFloat; __shared__ union { typename Load::TempStorage load; typename Store::TempStorage store; typename LoadFloat::TempStorage loadf; typename StoreFloat::TempStorage storef; } temp_storage; for (unsigned int i = base_idx; i < n_full; i += gridDim.x * TH * NUM_PER_THREAD) { valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i; __syncthreads(); Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); __syncthreads(); Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); // Load additional state1 data for AdEMAMix // TODO: Make constexpr after updating min compiler if (OPTIMIZER == ADEMAMIX) { __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(state1[n + i]), s3_vals, valid_items); } #pragma unroll 4 for (unsigned int j = 0; j < NUM_PER_THREAD; j++) g_vals[j] = gnorm_scale * ((float)g_vals[j]); #pragma unroll 4 for (unsigned int j = 0; j < NUM_PER_THREAD; j++) { switch (OPTIMIZER) { case ADEMAMIX: // m1 update: m1 = beta1 * m1 + (1-beta1) * g s1_vals[j] = (s1_vals[j] * beta1) + ((1.0f - beta1) * (float)g_vals[j]); // m2 update: m2 = m2 * beta3 + (1-beta3) * g s3_vals[j] = (s3_vals[j] * beta3) + ((1.0f - beta3) * (float)g_vals[j]); // nu update: nu = beta2 * nu + (1-beta2) * g^2 s2_vals[j] = (s2_vals[j] * beta2) + ((1.0f - beta2) * (float)g_vals[j] * (float)g_vals[j]); p_vals[j] = (float)p_vals[j] - lr * (((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / ((sqrtf(s2_vals[j]) / correction2) + eps)); if (weight_decay > 0.0f) p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); break; case ADAM: if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j])); s2_vals[j] = s2_vals[j] * beta2 + ((1.0f - beta2) * (((float)g_vals[j]) * ((float)g_vals[j]))); p_vals[j] = ((float)p_vals[j]) + (update_scale * step_size * (s1_vals[j] / (sqrtf(s2_vals[j]) + (eps * correction2)))); if (weight_decay > 0.0f) p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); } break; } } __syncthreads(); Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); __syncthreads(); StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); __syncthreads(); StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); if (OPTIMIZER == ADEMAMIX) { __syncthreads(); StoreFloat(temp_storage.storef).Store(&(state1[n + i]), s3_vals, valid_items); } } } template __launch_bounds__(BLOCK_SIZE / NUM_VALS, 1) __global__ void kPreconditionOptimizer32bit1State( T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n ) { const int n_full = (BLOCK_SIZE * (n / BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); int valid_items = 0; T g_vals[NUM_VALS]; float s1_vals[NUM_VALS]; typedef cub::BlockLoad Load; typedef cub::BlockLoad LoadFloat; typedef cub::BlockReduce BlockReduce; __shared__ union { typename Load::TempStorage load; typename LoadFloat::TempStorage loadf; typename BlockReduce::TempStorage reduce; } temp_storage; for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; __syncthreads(); Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); #pragma unroll NUM_VALS for (unsigned int j = 0; j < NUM_VALS; j++) g_vals[j] = gnorm_scale * ((float)g_vals[j]); #pragma unroll NUM_VALS for (unsigned int j = 0; j < NUM_VALS; j++) { switch (OPTIMIZER) { case MOMENTUM: if (step == 1) s1_vals[j] = (float)g_vals[j]; // state update else s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]); // state update s1_vals[j] = s1_vals[j] * s1_vals[j]; // update norm break; case LION: s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * (float)g_vals[j]); // state update break; case RMSPROP: s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j]) * ((float)g_vals[j])); // state update s1_vals[j] = __fdividef((float)g_vals[j], sqrtf(s1_vals[j]) + eps); // update value s1_vals[j] = s1_vals[j] * s1_vals[j]; // update norm break; case ADAGRAD: s1_vals[j] = s1_vals[j] + ((float)g_vals[j]) * ((float)g_vals[j]); // state update s1_vals[j] = __fdividef((float)g_vals[j], sqrtf(s1_vals[j]) + eps); // update value s1_vals[j] = s1_vals[j] * s1_vals[j]; // update norm break; } } #pragma unroll for (unsigned int j = 1; j < NUM_VALS; j++) s1_vals[0] += s1_vals[j]; __syncthreads(); s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); if (threadIdx.x == 0) atomicAdd(&unorm[0], s1_vals[0]); __syncwarp(); } } template __launch_bounds__(TH, 1) __global__ void kOptimizer32bit1State( T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n ) { const int n_full = ((TH * NUM_PER_THREAD) * (n / (TH * NUM_PER_THREAD))) + (n % (TH * NUM_PER_THREAD) == 0 ? 0 : (TH * NUM_PER_THREAD)); const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); int valid_items = 0; float update_scale = 0.0f; if (max_unorm > 0.0f) { update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; if (update_scale > max_unorm * param_norm + eps) { update_scale = (max_unorm * param_norm + eps) / update_scale; } else { update_scale = 1.0f; } } else { update_scale = 1.0f; } T g_vals[NUM_PER_THREAD]; T p_vals[NUM_PER_THREAD]; float s1_vals[NUM_PER_THREAD]; typedef cub::BlockLoad Load; typedef cub::BlockStore Store; typedef cub::BlockLoad LoadFloat; typedef cub::BlockStore StoreFloat; __shared__ union { typename Load::TempStorage load; typename Store::TempStorage store; typename LoadFloat::TempStorage loadf; typename StoreFloat::TempStorage storef; } temp_storage; for (unsigned int i = base_idx; i < n_full; i += gridDim.x * TH * NUM_PER_THREAD) { valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i; __syncthreads(); Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); __syncthreads(); Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); #pragma unroll 4 for (unsigned int j = 0; j < NUM_PER_THREAD; j++) { g_vals[j] = gnorm_scale * ((float)g_vals[j]); if (weight_decay > 0.0f) g_vals[j] = (float)g_vals[j] + (((float)p_vals[j]) * weight_decay); } #pragma unroll 4 for (unsigned int j = 0; j < NUM_PER_THREAD; j++) { if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { switch (OPTIMIZER) { case MOMENTUM: if (step == 1) s1_vals[j] = (float)g_vals[j]; else s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]); p_vals[j] = ((float)p_vals[j]) + update_scale * (-lr * (s1_vals[j])); break; case LION: p_vals[j] = ((float)p_vals[j]) - update_scale * (lr * sgn(((float)s1_vals[j]) * beta1 + ((1.0f - beta1) * ((float)g_vals[j])))); s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * ((float)g_vals[j])); break; case RMSPROP: s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * ((float)g_vals[j]) * ((float)g_vals[j])); p_vals[j] = ((float)p_vals[j]) - update_scale * (lr * __fdividef((float)g_vals[j], sqrtf((float)s1_vals[j]) + eps)); break; case ADAGRAD: s1_vals[j] = s1_vals[j] + ((float)g_vals[j]) * ((float)g_vals[j]); p_vals[j] = ((float)p_vals[j]) - lr * __fdividef((float)g_vals[j], sqrtf((float)s1_vals[j]) + eps); break; } } } __syncthreads(); Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); __syncthreads(); StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); } } #define NUM8BIT 16 #define NUM_THREADS 256 #define NUM_PER_BLOCK 4096 template __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8bit2State( T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2, float* unorm, const float beta1, const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, const float gnorm_scale, const int n ) { const int n_full = gridDim.x * NUM_PER_BLOCK; const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); int valid_items = n - (blockIdx.x * NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x * NUM_PER_BLOCK); float g_val = 0.0f; float local_max_s1 = -FLT_MAX; float local_max_s2 = -FLT_MAX; float local_unorm = 0.0f; float s2_vals[NUM8BIT]; float s1_vals[NUM8BIT]; T g_vals[NUM8BIT]; unsigned char m_c1[NUM8BIT]; unsigned char r_c2[NUM8BIT]; typedef cub::BlockLoad LoadT; typedef cub::BlockLoad LoadUInt8; typedef cub::BlockReduce BlockReduce; __shared__ union { typename LoadT::TempStorage loadh; typename LoadUInt8::TempStorage loadc; typename BlockReduce::TempStorage reduce; } temp_storage; __shared__ float smem_quantiles1[256]; __shared__ float smem_quantiles2[256]; if (threadIdx.x < 256) { smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x]; } __syncthreads(); for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS * gridDim.x * NUM8BIT) { valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i; LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); __syncthreads(); LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); __syncthreads(); LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); __syncthreads(); #pragma unroll 16 for (int j = 0; j < NUM8BIT; j++) { g_val = g_vals[j]; g_val *= gnorm_scale; s1_vals[j] = smem_quantiles1[m_c1[j]] * max1[0] * beta1; s1_vals[j] += (1.0f - beta1) * g_val; local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); } #pragma unroll 16 for (int j = 0; j < NUM8BIT; j++) { g_val = g_vals[j]; g_val *= gnorm_scale; s2_vals[j] = smem_quantiles2[r_c2[j]] * max2[0] * beta2; s2_vals[j] += (1.0f - beta2) * g_val * g_val; local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j])); } if (unorm != NULL) { #pragma unroll 16 for (int j = 0; j < NUM8BIT; j++) { float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step)); float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step)); s1_vals[j] *= correction1; s2_vals[j] *= correction2; float update_val = s1_vals[j] / (sqrtf(s2_vals[j]) + eps); // update local_unorm += update_val * update_val; } } } __syncthreads(); local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, CUB_REDUCTIONOP_MAX, valid_items); __syncthreads(); local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, CUB_REDUCTIONOP_MAX, valid_items); if (unorm != NULL) { __syncthreads(); local_unorm = BlockReduce(temp_storage.reduce).Sum(local_unorm, valid_items); } if (threadIdx.x == 0) { atomicMax(&new_max1[0], local_max_s1); atomicMax(&new_max2[0], local_max_s2); if (unorm != NULL) { atomicAdd(&unorm[0], local_unorm); } } } #define NUM_PER_THREAD2 4 #define NUM_THREADS2 1024 #define NUM_PER_BLOCK2 4096 template __global__ void __launch_bounds__(NUM_THREADS2, 1) kOptimizerStatic8bit2State( T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n ) { const int n_full = (blockDim.x * gridDim.x) * NUM_PER_THREAD2; const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); int valid_items = 0; float g_val = 0.0f; float s1_vals[NUM_PER_THREAD2]; float s2_vals[NUM_PER_THREAD2]; const float correction1 = 1.0f - powf(beta1, step); const float correction2 = sqrtf(1.0f - powf(beta2, step)); const float step_size = -lr * correction2 / correction1; // const float step_size = -lr*correction2/correction1; float new_max_val1 = 1.0f / new_max1[0]; float new_max_val2 = 1.0f / new_max2[0]; float update_scale = 1.0f; if (max_unorm > 0.0f) { update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; if (update_scale > max_unorm * param_norm) { update_scale = (max_unorm * param_norm) / update_scale; } else { update_scale = 1.0f; } } else { update_scale = 1.0f; } unsigned char c1s[NUM_PER_THREAD2]; unsigned char c2s[NUM_PER_THREAD2]; T p_vals[NUM_PER_THREAD2]; T g_vals[NUM_PER_THREAD2]; typedef cub::BlockLoad LoadT; typedef cub::BlockLoad LoadChar; typedef cub::BlockStore StoreChar; typedef cub::BlockStore StoreT; __shared__ float smem_quantiles1[256]; __shared__ float smem_quantiles2[256]; __shared__ union { typename LoadT::TempStorage loadh; typename LoadChar::TempStorage loadc; typename StoreChar::TempStorage storec; typename StoreT::TempStorage storeh; } temp_storage; if (threadIdx.x < 512) { if (threadIdx.x < 256) smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; else smem_quantiles2[threadIdx.x - 256] = quantiles2[threadIdx.x - 256]; } __syncthreads(); for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_THREADS2 * NUM_PER_THREAD2) { valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i; LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); __syncthreads(); LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); __syncthreads(); LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); __syncthreads(); LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); if ((i + (threadIdx.x * NUM_PER_THREAD2) + NUM_PER_THREAD2) > n) { continue; } #pragma unroll 4 for (unsigned int j = 0; j < NUM_PER_THREAD2; j++) { g_val = float(g_vals[j]); g_val *= gnorm_scale; s1_vals[j] = smem_quantiles1[c1s[j]]; s1_vals[j] = s1_vals[j] * max1[0]; s1_vals[j] = (s1_vals[j] * beta1) + (((1.0f - beta1) * g_val)); c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j] * new_max_val1); // make sure state1 term has still the same sign after quantization // (not needed for state2 term which has only positive values) if (signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) { if (s1_vals[j] > 0.0f) c1s[j] += 1; else c1s[j] -= 1; } s2_vals[j] = smem_quantiles2[c2s[j]]; s2_vals[j] = s2_vals[j] * max2[0]; s2_vals[j] = (s2_vals[j] * beta2) + (((1.0f - beta2) * g_val * g_val)); c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j] * new_max_val2); } #pragma unroll 4 for (unsigned int j = 0; j < NUM_PER_THREAD2; j++) { p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale * step_size * (s1_vals[j] / (sqrtf(s2_vals[j]) + (correction2 * eps)))))); if (weight_decay > 0.0f) p_vals[j] = update_scale * ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); } StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); __syncthreads(); StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); __syncthreads(); StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); __syncthreads(); } } template __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8bit1State( T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, const float beta1, const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1, float* new_max1, const float weight_decay, const float gnorm_scale, const int n ) { const int n_full = gridDim.x * NUM_PER_BLOCK; const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); int valid_items = n - (blockIdx.x * NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x * NUM_PER_BLOCK); float g_val = 0.0f; float local_max_s1 = -FLT_MAX; float local_unorm = 0.0f; float s1_vals[NUM8BIT]; T g_vals[NUM8BIT]; unsigned char m_c1[NUM8BIT]; typedef cub::BlockLoad LoadT; typedef cub::BlockLoad LoadUInt8; typedef cub::BlockReduce BlockReduce; __shared__ union { typename LoadT::TempStorage loadh; typename LoadUInt8::TempStorage loadc; typename BlockReduce::TempStorage reduce; } temp_storage; __shared__ float smem_quantiles1[256]; if (threadIdx.x < 256) smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; __syncthreads(); for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_THREADS * NUM8BIT) { valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i; __syncthreads(); LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); __syncthreads(); LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); #pragma unroll 16 for (int j = 0; j < NUM8BIT; j++) { g_val = g_vals[j]; g_val *= gnorm_scale; s1_vals[j] = smem_quantiles1[m_c1[j]] * max1[0]; switch (OPTIMIZER) { case ADAGRAD: case MOMENTUM: if (step == 1) s1_vals[j] = (float)g_vals[j]; else s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]); if (unorm != NULL) local_unorm += s1_vals[j] * s1_vals[j]; break; case LION: s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * g_val); break; case RMSPROP: s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * (g_val * g_val)); break; } local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); } } __syncthreads(); local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, CUB_REDUCTIONOP_MAX, valid_items); if (threadIdx.x == 0) { atomicMax(&new_max1[0], local_max_s1); } if (unorm != NULL) { __syncthreads(); local_unorm = BlockReduce(temp_storage.reduce).Sum(local_unorm, valid_items); if (threadIdx.x == 0) { atomicAdd(&unorm[0], local_unorm); } } } template __global__ void __launch_bounds__(1024, 1) kOptimizerStatic8bit1State( T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale, const int n ) { const int n_full = (blockDim.x * gridDim.x) * NUM_PER_THREAD2; const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); int valid_items = 0; float g_val = 0.0f; float s1_vals[NUM_PER_THREAD2]; float new_max_val1 = 1.0f / new_max1[0]; float update_scale = 1.0f; if (max_unorm > 0.0f) { update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; if (update_scale > max_unorm * param_norm) { update_scale = (max_unorm * param_norm) / update_scale; } else { update_scale = 1.0f; } } else { update_scale = 1.0f; } unsigned char c1s[NUM_PER_THREAD2]; T p_vals[NUM_PER_THREAD2]; T g_vals[NUM_PER_THREAD2]; typedef cub::BlockLoad LoadT; typedef cub::BlockLoad LoadChar; typedef cub::BlockStore StoreChar; typedef cub::BlockStore StoreT; __shared__ float smem_quantiles1[256]; __shared__ union { typename LoadT::TempStorage loadh; typename LoadChar::TempStorage loadc; typename StoreChar::TempStorage storec; typename StoreT::TempStorage storeh; } temp_storage; if (threadIdx.x < 256) smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; __syncthreads(); for (unsigned int i = base_idx; i < n_full; i += gridDim.x * NUM_THREADS2 * NUM_PER_THREAD2) { valid_items = n - i >= (TH * NUM_PER_THREAD) ? (TH * NUM_PER_THREAD) : n - i; LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); __syncthreads(); LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); __syncthreads(); LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); if ((i + (threadIdx.x * NUM_PER_THREAD2) + NUM_PER_THREAD2) > n) { continue; } #pragma unroll 4 for (unsigned int j = 0; j < NUM_PER_THREAD2; j++) { g_val = float(g_vals[j]); g_val *= gnorm_scale; if (weight_decay > 0.0f) { switch (OPTIMIZER) { case ADAGRAD: case MOMENTUM: case RMSPROP: g_val += ((float)p_vals[j]) * weight_decay; break; case LION: p_vals[j] = ((float)p_vals[j]) * (1.0f - lr * weight_decay); break; } } s1_vals[j] = smem_quantiles1[c1s[j]] * max1[0]; switch (OPTIMIZER) { case ADAGRAD: case MOMENTUM: if (step == 1) s1_vals[j] = g_vals[j]; else s1_vals[j] = s1_vals[j] * beta1 + ((float)g_vals[j]); p_vals[j] = ((float)p_vals[j]) + (-lr * update_scale * (s1_vals[j])); break; case LION: p_vals[j] = ((float)p_vals[j]) - (lr * sgn(((float)s1_vals[j]) * beta1 + ((1.0f - beta1) * ((float)g_val)))); s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * g_val); break; case RMSPROP: s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * (g_val * g_val)); p_vals[j] = ((float)p_vals[j]) - (lr * __fdividef(g_val, sqrtf(s1_vals[j]) + eps)); break; } c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j] * new_max_val1); // make sure state1 term has still the same sign after quantization if (signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) { if (s1_vals[j] > 0.0f) c1s[j] += 1; else c1s[j] -= 1; } } StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); __syncthreads(); StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); __syncthreads(); } } template __global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n) { const int n_full = (BLOCK_SIZE * (n / BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); int valid_items = 0; typedef cub::BlockReduce BlockReduce; typedef cub::BlockLoad LoadT; __shared__ typename BlockReduce::TempStorage reduce; __shared__ typename LoadT::TempStorage loadT; T vals[NUM_VALS]; float local_sum = 0.0f; for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x * BLOCK_SIZE) { valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; local_sum = 0.0f; __syncthreads(); LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); #pragma unroll NUM_VALS for (int j = 0; j < NUM_VALS; j++) local_sum += ((float)vals[j]) * ((float)vals[j]); local_sum = BlockReduce(reduce).Sum(local_sum, valid_items); if (threadIdx.x == 0) { if (step == 1) { // initialize with the same norm for all positions // #pragma unroll 10 for (int j = 0; j < 100; j++) atomicAdd(&gnorm_vec[j], local_sum); } else atomicAdd(&gnorm_vec[step % 100], local_sum); } } } #define LANES 2 #define QUAD 3 template __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit2StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n ) { // const int n_full = n + (n%BLOCK_SIZE); const int n_full = gridDim.x * BLOCK_SIZE; const int base_idx = (blockIdx.x * BLOCK_SIZE); int valid_items = 0; float g_val = 0.0f; float s1_vals[N_PER_TH]; float s2_vals[N_PER_TH]; float s3_vals[N_PER_TH]; // 2-5% const float correction1 = 1.0f - __powf(beta1, step); const float correction2 = sqrtf(1.0f - __powf(beta2, step)); const float step_size = __fdividef(-lr * correction2, correction1); const int lane_id = threadIdx.x % LANES; float new_local_abs_max1 = -FLT_MAX; float new_local_abs_max2 = -FLT_MAX; float new_local_abs_max3 = -FLT_MAX; float quadrants1[QUAD]; float quadrants2[QUAD]; unsigned char c1s[N_PER_TH]; unsigned char c2s[N_PER_TH]; unsigned char c3s[N_PER_TH]; T g_vals[N_PER_TH]; T p_vals[N_PER_TH]; typedef cub::BlockLoad LoadT; typedef cub::BlockLoad LoadChar; typedef cub::BlockStore StoreChar; typedef cub::BlockStore StoreT; __shared__ float smem_quantiles1[LANES][257]; __shared__ float smem_quantiles2[LANES][257]; typedef cub::BlockReduce BlockReduce1; typedef cub::BlockReduce BlockReduce2; typedef cub::BlockReduce BlockReduce3; __shared__ typename BlockReduce1::TempStorage reduce1; __shared__ typename BlockReduce2::TempStorage reduce2; __shared__ typename BlockReduce2::TempStorage reduce3; __shared__ float smem_exchange1[1]; __shared__ float smem_exchange2[1]; __shared__ float smem_exchange3[1]; // [[maybe_unused]] __shared__ union { typename LoadT::TempStorage loadh; typename LoadChar::TempStorage loadc; typename StoreChar::TempStorage storec; typename StoreT::TempStorage storeh; } temp_storage; // init: 0.2 -> 0.23 // 0.23 -> 0.23 smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x]; #pragma unroll for (unsigned int j = 1; j < LANES; j++) { smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x]; } __syncthreads(); #pragma unroll for (int k = 0; k < QUAD; k++) { quadrants1[k] = smem_quantiles1[lane_id][(k * 256 / (QUAD + 1)) + (256 / (QUAD + 1) - 1)]; quadrants2[k] = smem_quantiles2[lane_id][(k * 256 / (QUAD + 1)) + (256 / (QUAD + 1) - 1)]; } for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { // loads: 0.23 -> 0.85/1.44 valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; __syncthreads(); LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); __syncthreads(); LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); __syncthreads(); LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); // AdEMAMix has an additional state packed into state1. if (OPTIMIZER == ADEMAMIX) { __syncthreads(); LoadChar(temp_storage.loadc).Load(&(state1[n + i]), c3s, valid_items, 128); } new_local_abs_max1 = -FLT_MAX; new_local_abs_max2 = -FLT_MAX; new_local_abs_max3 = -FLT_MAX; // update: 2.48/1.57 -> 2.51/1.60 #pragma unroll N_PER_TH for (unsigned int j = 0; j < N_PER_TH; j++) { if (!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) { s2_vals[j] = smem_quantiles2[lane_id][c2s[j]] * absmax2[i / BLOCK_SIZE]; g_val = g_vals[j]; // float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); // g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; g_val *= gnorm_scale; s2_vals[j] = (s2_vals[j] * beta2) + (((1.0f - beta2) * g_val * g_val)); s1_vals[j] = smem_quantiles1[lane_id][c1s[j]] * absmax1[i / BLOCK_SIZE]; s1_vals[j] = (s1_vals[j] * beta1) + (((1.0f - beta1) * g_val)); if (OPTIMIZER == ADEMAMIX) { // The absmax for the third state is appended to absmax1 s3_vals[j] = smem_quantiles1[lane_id][c3s[j]] * absmax1[(n + i) / BLOCK_SIZE]; s3_vals[j] = (s3_vals[j] * beta3) + (((1.0f - beta3) * g_val)); } } else { s1_vals[j] = 0.0f; s2_vals[j] = 0.0f; if (OPTIMIZER == ADEMAMIX) { s3_vals[j] = 0.0f; } } new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); if (OPTIMIZER == ADEMAMIX) { new_local_abs_max3 = fmaxf(new_local_abs_max3, fabsf(s3_vals[j])); } } // reduce: 2.51/1.60 -> 2.67/1.69 new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, CUB_REDUCTIONOP_MAX); new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, CUB_REDUCTIONOP_MAX); if (OPTIMIZER == ADEMAMIX) { new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, CUB_REDUCTIONOP_MAX); } if (threadIdx.x == 0) { smem_exchange1[0] = new_local_abs_max1; smem_exchange2[0] = new_local_abs_max2; if (OPTIMIZER == ADEMAMIX) { smem_exchange3[0] = new_local_abs_max3; } } __syncthreads(); if (threadIdx.x == 0) { absmax1[i / BLOCK_SIZE] = new_local_abs_max1; absmax2[i / BLOCK_SIZE] = new_local_abs_max2; if (OPTIMIZER == ADEMAMIX) { absmax1[(n + i) / BLOCK_SIZE] = new_local_abs_max3; } } else { new_local_abs_max1 = smem_exchange1[0]; new_local_abs_max2 = smem_exchange2[0]; if (OPTIMIZER == ADEMAMIX) { new_local_abs_max3 = smem_exchange3[0]; } } __syncthreads(); LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); // reduce: 2.67/1.69 -> 2.67/1.70 #pragma unroll N_PER_TH for (unsigned int j = 0; j < N_PER_TH; j++) { // if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) if (!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) { if (OPTIMIZER == ADEMAMIX) { p_vals[j] = T((float)p_vals[j] - lr * (((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / ((sqrtf(s2_vals[j]) / correction2) + eps))); } else { p_vals[j] = (T)(((float)p_vals[j]) + ((step_size * (__fdividef(s1_vals[j], (sqrtf(s2_vals[j]) + (correction2 * eps))))))); } if (weight_decay > 0.0f) p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); } } // store: 0.85/1.44 -> 2.48/1.57 __syncthreads(); StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); // quantizaztion: 2.67/1.70 -> 3.4/3.3 #pragma unroll N_PER_TH for (unsigned int j = 0; j < N_PER_TH; j++) { c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j], new_local_abs_max1)); c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j], new_local_abs_max2)); // make sure state1 term has still the same sign after quantization // (not needed for state2 term which has only positive values) if (signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) { if (s1_vals[j] > 0.0f) c1s[j] += 1; else c1s[j] -= 1; } if (OPTIMIZER == ADEMAMIX) { c3s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s3_vals[j], new_local_abs_max3)); if (signbit(smem_quantiles1[lane_id][c3s[j]]) != signbit(s3_vals[j])) { c3s[j] += (s3_vals[j] > 0.0f) ? 1 : -1; } } } __syncthreads(); StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); __syncthreads(); StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); if (OPTIMIZER == ADEMAMIX) { __syncthreads(); StoreChar(temp_storage.storec).Store(&(state1[n + i]), c3s, valid_items); } } } #define LANES 2 #define QUAD 3 template __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit1StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n ) { // const int n_full = n + (n%BLOCK_SIZE); const int n_full = gridDim.x * BLOCK_SIZE; const int base_idx = (blockIdx.x * BLOCK_SIZE); int valid_items = 0; float g_val = 0.0f; float s1_vals[N_PER_TH]; // 2-5% const int lane_id = threadIdx.x % LANES; float new_local_abs_max1 = -FLT_MAX; float quadrants1[QUAD]; unsigned char c1s[N_PER_TH]; T g_vals[N_PER_TH]; T p_vals[N_PER_TH]; typedef cub::BlockLoad LoadT; typedef cub::BlockLoad LoadChar; typedef cub::BlockStore StoreChar; typedef cub::BlockStore StoreT; __shared__ float smem_quantiles1[LANES][257]; typedef cub::BlockReduce BlockReduce1; __shared__ typename BlockReduce1::TempStorage reduce1; __shared__ float smem_exchange1[1]; __shared__ union { typename LoadT::TempStorage loadh; typename LoadChar::TempStorage loadc; typename StoreChar::TempStorage storec; typename StoreT::TempStorage storeh; } temp_storage; // init: 0.2 -> 0.23 // 0.23 -> 0.23 smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; #pragma unroll for (unsigned int j = 1; j < LANES; j++) smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; __syncthreads(); #pragma unroll for (int k = 0; k < QUAD; k++) quadrants1[k] = smem_quantiles1[lane_id][(k * 256 / (QUAD + 1)) + (256 / (QUAD + 1) - 1)]; for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { // loads: 0.23 -> 0.85/1.44 valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; __syncthreads(); LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); __syncthreads(); LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); __syncthreads(); LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); new_local_abs_max1 = -FLT_MAX; // update: 2.48/1.57 -> 2.51/1.60 #pragma unroll N_PER_TH for (unsigned int j = 0; j < N_PER_TH; j++) { g_val = float(g_vals[j]); g_val *= gnorm_scale; if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { if (weight_decay > 0.0f) { switch (OPTIMIZER) { case MOMENTUM: case ADAGRAD: case RMSPROP: g_val += ((float)p_vals[j]) * weight_decay; break; case LION: p_vals[j] = ((float)p_vals[j]) * (1.0f - lr * weight_decay); break; } } s1_vals[j] = smem_quantiles1[lane_id][c1s[j]] * absmax1[i / BLOCK_SIZE]; switch (OPTIMIZER) { case MOMENTUM: if (step == 1) s1_vals[j] = g_val; else s1_vals[j] = (s1_vals[j] * beta1) + g_val; break; case LION: // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, // before the momentum is updated by beta2 g_vals[j] = lr * sgn(((float)s1_vals[j]) * beta1 + ((1.0f - beta1) * g_val)); s1_vals[j] = s1_vals[j] * beta2 + ((1.0f - beta2) * g_val); break; case RMSPROP: s1_vals[j] = s1_vals[j] * beta1 + ((1.0f - beta1) * (g_val * g_val)); break; case ADAGRAD: s1_vals[j] = s1_vals[j] + (g_val * g_val); break; } } new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); } // reduce: 2.51/1.60 -> 2.67/1.69 new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, CUB_REDUCTIONOP_MAX); if (threadIdx.x == 0) smem_exchange1[0] = new_local_abs_max1; __syncthreads(); if (threadIdx.x == 0) absmax1[i / BLOCK_SIZE] = new_local_abs_max1; else new_local_abs_max1 = smem_exchange1[0]; // reduce: 2.67/1.69 -> 2.67/1.70 #pragma unroll N_PER_TH for (unsigned int j = 0; j < N_PER_TH; j++) { if (!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { switch (OPTIMIZER) { case MOMENTUM: p_vals[j] = ((float)p_vals[j]) - lr * (s1_vals[j]); break; case LION: p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); break; case RMSPROP: g_val = g_vals[j]; p_vals[j] = ((float)p_vals[j]) - lr * (__fdividef(g_val, sqrtf(s1_vals[j]) + eps)); break; case ADAGRAD: g_val = g_vals[j]; p_vals[j] = ((float)p_vals[j]) - lr * (__fdividef(g_val, sqrtf(s1_vals[j]) + eps)); break; } } } // store: 0.85/1.44 -> 2.48/1.57 __syncthreads(); StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); // quantizaztion: 2.67/1.70 -> 3.4/3.3 #pragma unroll N_PER_TH for (unsigned int j = 0; j < N_PER_TH; j++) { c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j], new_local_abs_max1)); // make sure state1 term has still the same sign after quantization // (not needed for state2 term which has only positive values) if (signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) { if (s1_vals[j] > 0.0f) c1s[j] += 1; else c1s[j] -= 1; } } __syncthreads(); StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); } } // Inputs: // A [rows, cols] // Outputs: // rowStats [rows] // out [rows, cols] template __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) { // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32. // Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped. #if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR >= 2) || BNB_FP16_AVAILABLE using TReduction = T; #else using TReduction = float; #endif using BlockReduceT = cub::BlockReduce; // One block per row. // Threads load column values in a striped arrangement. // e.g. t0 reads row[0], row[0+nthreads], .. // and t1 reads row[1], row[1+nthreads], .. // Each thread will determine its local absmax. // We then do a blockwise reduction to determine the row's absmax. __shared__ typename BlockReduceT::TempStorage temp_storage; __shared__ TReduction smem_row_absmax; const int row_id = blockIdx.x; const T* row_data = A + (row_id * cols); // Threads will read the row values in a striped access pattern and find a local absmax. TReduction row_local_absmax = -FLT_MIN; for (int i = threadIdx.x; i < cols; i += THREADS) { const TReduction absval = fabsf(__ldcs(&(row_data[i]))); // For sparse decomposition, values outside of the threshold are not to be // included when calculating the row's absmax. if constexpr (SPARSE_DECOMP) { row_local_absmax = fmaxf(row_local_absmax, absval < TReduction(threshold) ? absval : row_local_absmax); } else { row_local_absmax = fmaxf(row_local_absmax, absval); } } // Reduce thread-local absmax across the block. const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, CUB_REDUCTIONOP_MAX, cols); if (threadIdx.x == 0) { // Save our block's absmax to shared memory for the quantization step. rowStats[row_id] = smem_row_absmax = row_absmax; } __syncthreads(); // Quantize row-wise. const float scale = __fdividef(127.0f, smem_row_absmax); for (int i = threadIdx.x; i < cols; i += THREADS) { float val = row_data[i]; if constexpr (SPARSE_DECOMP) { // For sparse decomposition, we do not want to quantize the outliers. // Instead they're zeroed out. out[row_id * cols + i] = fabs(val) < threshold ? __float2int_rn(val * scale) : 0; } else { out[row_id * cols + i] = __float2int_rn(val * scale); } } } template __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols) { using BlockReduceT = cub::BlockReduce; // One block per row. // Threads load column values in a striped arrangement. // e.g. t0 reads row[0], row[0+nthreads], .. // and t1 reads row[1], row[1+nthreads], .. // Each thread will determine its local absmax. // We then do a blockwise reduction to determine the row's absmax. __shared__ typename BlockReduceT::TempStorage temp_storage; const int row_id = blockIdx.x; const T* __restrict__ row_data = A + (row_id * cols); // Threads will read the row values in a striped access pattern and find a local absmax. float row_local_absmax = -FLT_MIN; for (int i = threadIdx.x; i < cols; i += THREADS) { const float absval = fabsf(row_data[i]); // For sparse decomposition, values outside of the threshold are not to be // included when calculating the row's absmax. if constexpr (SPARSE_DECOMP) { row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax); } else { row_local_absmax = fmaxf(row_local_absmax, absval); } } // Reduce thread-local absmax across the block. // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, CUB_REDUCTIONOP_MAX, cols); if (threadIdx.x == 0) { // Save our block's absmax to shared memory for the quantization step. rowStats[row_id] = row_absmax; } } template __global__ void kgetRowStats(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols); template __global__ void kgetRowStats(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols); template __global__ void kInt8VectorQuant( half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols ); template __global__ void kInt8VectorQuant( half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols ); #define MM_DEQUANT_CONST 6.200012e-05f // 1.0f/(127.0f*127.0f) template __global__ void kdequant_mm_int32_fp16( int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, half* __restrict__ const bias, const int numRows, const int numCols, const int n ) { const int n_out = numRows * numCols; int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD; int thread_offset = threadIdx.x * ITEMS_PER_THREAD; int local_values[ITEMS_PER_THREAD]; half local_output[ITEMS_PER_THREAD]; float local_rowStats[ITEMS_PER_THREAD]; float local_colStats[ITEMS_PER_THREAD]; float local_biasValue[ITEMS_PER_THREAD]; typedef cub::BlockLoad LoadInt32; __shared__ typename LoadInt32::TempStorage loadint32; int row_idx, col_idx; #pragma unroll ITEMS_PER_THREAD for (int j = 0; j < ITEMS_PER_THREAD; ++j) { row_idx = (block_offset + thread_offset + j) / numCols; col_idx = (block_offset + thread_offset + j) % numCols; local_colStats[j] = col_idx >= numCols ? 0.0f : __ldg(&colStats[col_idx]); local_rowStats[j] = row_idx >= numRows ? 0.0f : __ldg(&rowStats[row_idx]); local_biasValue[j] = ((bias == nullptr) || col_idx >= numCols) ? 0.0f : __half2float(bias[col_idx]); } // Each block loads THREADS * ITEMS_PER_THREAD values from A int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD : n_out - block_offset; LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); #pragma unroll ITEMS_PER_THREAD for (int j = 0; j < ITEMS_PER_THREAD; ++j) { local_output[j] = __float2half( fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j]) ); } #pragma unroll ITEMS_PER_THREAD for (int j = 0; j < ITEMS_PER_THREAD; j++) { int outIdx = block_offset + thread_offset + j; if (outIdx < n_out) { out[outIdx] = local_output[j]; } } } #define DENORM 1.0f / 127.0f #define MAX_SPARSE_COUNT 32 #define SMEM_SIZE 8 * 256 template __global__ void kspmm_coo_very_sparse_naive( int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB ) { // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block // If a block finishes, the next one is scheduled. Since the last blocks like have fewer // elements they finish faster "fillin up" the gaps left by larger blocks // without tensor cores // 1. use rowidx_length to find what to load (as many blocks as there are rows) // 2. Load A into registers // 3. each warp loads all required rows of B but each warp is offset by k // 4. Do mma operations that accumulate into registers // 5. Each warp stores its output row into matrix C const int count = max_count[blockIdx.x]; const int local_max_idx = max_idx[blockIdx.x]; const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx - 1]; const int local_row_idx = rowidx[offset]; const int warp_id = threadIdx.x / 32; const int warp_idx = threadIdx.x % 32; const int warp_offset = (warp_id * 32) * SPMM_ITEMS; const int num_items = BITS == 8 ? 8 : 8; int idx_col_B = warp_offset; int local_idx_col_B_offset = 0; half local_valA[MAX_SPARSE_COUNT]; int local_colidxA[MAX_SPARSE_COUNT]; half local_valC[SPMM_ITEMS]; T local_valsB[num_items]; half local_valOut[num_items]; // 128 byte loads per warp == 4 bytes per thread // 2. Load A into registers for (int j = 0; j < MAX_SPARSE_COUNT; j++) { local_valA[j] = j < count ? values[offset + j] : __float2half(0.0f); local_colidxA[j] = j < count ? colidx[offset + j] : 0; } // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 // we expect each warp to be SPMM_ITEMS*32 apart // we have a total of 128 bytes for the bank with a bank size of 4 bytes // added 3 bytes = 6 values between warps should reduce bank conflicts __shared__ half smem_dequant_stats[SMEM_SIZE]; while (idx_col_B < colsB) { if (dequant_stats != NULL) { for (int i = threadIdx.x; i < SMEM_SIZE; i += blockDim.x) if ((idx_col_B + i - local_idx_col_B_offset) < colsB) smem_dequant_stats[i] = dequant_stats[idx_col_B + i - local_idx_col_B_offset]; __syncthreads(); } #pragma unroll SPMM_ITEMS for (int j = 0; j < SPMM_ITEMS; j++) local_valC[j] = 0.0f; #pragma unroll for (int i = 0; i < count; i++) { // 3. each warp loads all required rows of B but each warp is offset by k int row_offset = colsB * local_colidxA[i]; #pragma unroll SPMM_ITEMS for (int j = 0; j < SPMM_ITEMS; j += num_items) { // 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached int idx = idx_col_B + (warp_idx * SPMM_ITEMS) + j; if (idx >= colsB) { break; } if ((idx + num_items < colsB)) { if (BITS == 8) reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset + idx) / num_items]; else reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset + idx) / num_items]; } else { #pragma unroll num_items for (int k = 0; k < num_items; k++) if (idx + k < colsB) local_valsB[k] = B[row_offset + idx + k]; else local_valsB[k] = 0.0f; } #pragma unroll num_items for (int k = 0; k < num_items; k++) { if (BITS == 8 && dequant_stats != NULL) // we do texture cache reads (__ldg) on dequant_stats which should be super fast { float valB = local_valsB[k]; float valA = local_valA[i]; if (valB != 0.0 && valA != 0.0) local_valC[j + k] = (float)local_valC[j + k] + ((float)smem_dequant_stats[idx + k - local_idx_col_B_offset]) * DENORM * valB * valA; } else local_valC[j + k] = (float)local_valC[j + k] + (float)local_valsB[k] * (float)local_valA[i]; } } } int idx_row_C = (colsB * local_row_idx); #pragma unroll SPMM_ITEMS for (int j = 0; j < SPMM_ITEMS; j += num_items) { // int idx_col_C = idx_col_B + (32*j) + warp_idx; int idx_col_C = idx_col_B + warp_idx * SPMM_ITEMS + j; int idx_val = idx_col_C + idx_row_C; if (idx_col_C + num_items < colsB) { // load outputs to do inplace addition reinterpret_cast(local_valOut)[0] = reinterpret_cast(out)[idx_val / num_items]; #pragma unroll num_items for (int k = 0; k < num_items; k++) local_valC[(j / num_items) + k] = (float)local_valC[(j / num_items) + k] + (float)local_valOut[k]; reinterpret_cast(out)[idx_val / num_items] = reinterpret_cast(local_valC)[j / num_items]; } else { #pragma unroll num_items for (int k = 0; k < num_items; k++) if (idx_col_C + k < colsB) out[idx_val + k] = (float)out[idx_val + k] + (float)local_valC[j + k]; } } idx_col_B += blockDim.x * SPMM_ITEMS; local_idx_col_B_offset += blockDim.x * SPMM_ITEMS; } } #define WARPS 3 template __global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc) { #if __CUDA_ARCH__ >= 750 using namespace nvcuda; int col_offset = blockIdx.x * 32; const int warp_id = threadIdx.x / 32; const int half_warp_id = threadIdx.x / 16; const int half_warp_lane = threadIdx.x % 16; const int batch_size_warps = (WARPS - 1) * 2; const int val_per_iter = blockDim.x - 32; T local_A[4]; T local_B[128]; const int a_tile_offset = 16; const int b_tile_offset = (16 * 32 + 16); __shared__ T smem_A[8 * 16 + (2 * 16 * (batch_size_warps - 1))]; __shared__ T smem_B[2 * batch_size_warps * 16 * 32 + (2 * 16 * (batch_size_warps - 1))]; //__shared__ T smem_C[8*32]; wmma::fragment a_frag; wmma::fragment b_frag; wmma::fragment c_frag; wmma::fill_fragment(c_frag, 0.0f); int ticktock = 0; int idx = 0 + threadIdx.x; int loaded_values = 0; // prefetch if (idx < K && warp_id < (WARPS - 1)) { if (loaded_values == 0) { local_A[0] = A[idx]; local_A[1] = A[idx + (1 * val_per_iter)]; local_A[2] = A[idx + (2 * val_per_iter)]; local_A[3] = A[idx + (3 * val_per_iter)]; #pragma unroll 32 for (int col = 0; col < 32; col++) { local_B[col] = B[(col_offset + col) * ldb + idx]; local_B[col + 32] = B[(col_offset + col) * ldb + idx + (1 * val_per_iter)]; local_B[col + 64] = B[(col_offset + col) * ldb + idx + (2 * val_per_iter)]; local_B[col + 96] = B[(col_offset + col) * ldb + idx + (3 * val_per_iter)]; } loaded_values = 3; } else { if (loaded_values == 3) { local_A[0] = local_A[1]; #pragma unroll 32 for (int col = 0; col < 32; col++) local_B[col] = local_B[col + (32)]; } else if (loaded_values == 2) { local_A[0] = local_A[2]; #pragma unroll 32 for (int col = 0; col < 32; col++) local_B[col] = local_B[col + (64)]; } else { local_A[0] = local_A[3]; #pragma unroll 32 for (int col = 0; col < 32; col++) local_B[col] = local_B[col + (96)]; } loaded_values--; } smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0]; #pragma unroll 32 for (int col = 0; col < 32; col++) smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = local_B[col]; } else if (warp_id < (WARPS - 1)) { local_A[0] = T(0.0); smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f; #pragma unroll 32 for (int col = 0; col < 32; col++) local_B[col] = 0.0f; #pragma unroll 32 for (int col = 0; col < 32; col++) smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = 0.0f; } ticktock = ticktock == 0 ? 1 : 0; // for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) for (int base_idx = blockDim.x - 32; base_idx < K; base_idx += blockDim.x - 32) { idx = base_idx + threadIdx.x; __syncthreads(); if (idx < K && warp_id < (WARPS - 1)) { // local_A[0] = A[idx]; // #pragma unroll 32 // for(int col = 0; col < 32; col++) // local_B[col] = B[(col_offset+col)*ldb+idx]; if (loaded_values == 0) { local_A[0] = A[idx]; local_A[1] = A[idx + (1 * val_per_iter)]; local_A[2] = A[idx + (2 * val_per_iter)]; local_A[3] = A[idx + (3 * val_per_iter)]; #pragma unroll 32 for (int col = 0; col < 32; col++) { local_B[col] = B[(col_offset + col) * ldb + idx]; local_B[col + 32] = B[(col_offset + col) * ldb + idx + (1 * val_per_iter)]; local_B[col + 64] = B[(col_offset + col) * ldb + idx + (2 * val_per_iter)]; local_B[col + 96] = B[(col_offset + col) * ldb + idx + (3 * val_per_iter)]; } loaded_values = 3; } else { if (loaded_values == 3) { local_A[0] = local_A[1]; #pragma unroll 32 for (int col = 0; col < 32; col++) local_B[col] = local_B[col + (32)]; } else if (loaded_values == 2) { local_A[0] = local_A[2]; #pragma unroll 32 for (int col = 0; col < 32; col++) local_B[col] = local_B[col + (64)]; } else { local_A[0] = local_A[3]; #pragma unroll 32 for (int col = 0; col < 32; col++) local_B[col] = local_B[col + (96)]; } loaded_values--; } smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0]; #pragma unroll 32 for (int col = 0; col < 32; col++) smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = local_B[col]; } else if (warp_id < (WARPS - 1)) { local_A[0] = T(0.0); smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f; #pragma unroll 32 for (int col = 0; col < 32; col++) local_B[col] = 0.0f; #pragma unroll 32 for (int col = 0; col < 32; col++) smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = 0.0f; } ticktock = ticktock == 0 ? 1 : 0; if (warp_id == (WARPS - 1)) for (int k = 0; k < batch_size_warps; k++) { wmma::load_matrix_sync( a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16 ); // 111 mu wmma::load_matrix_sync( b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16 ); // 35 mu wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); } } __syncthreads(); if (warp_id != (WARPS - 1)) { return; } // only warp_id == (WARPS-1) from here int warp_lane = threadIdx.x % 32; ticktock = ticktock == 0 ? 1 : 0; for (int k = 0; k < batch_size_warps; k++) { wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16); // 111 mu wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16); // 35 mu wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); } // 129 mu if (warp_id == (WARPS - 1)) wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); if (col_offset + warp_lane < M) out[col_offset + warp_lane] = smem_A[warp_lane]; #endif } template __device__ void printnonzero(T* A, int num_values, const char* strval) { for (int i = 0; i < num_values; i++) if ((float)A[i] != 0.0) printf("%s %i %f\n", strval, i, (float)A[i]); } template __global__ void kgemm_4bit_inference( int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize ) { //// element-wise kernel //// 1. Load batch x k into registers //// 2. Load k x k into registers //// 3. dequantize and store in second pair of k x k //// 4. matmul //// 5. sum with cub //// 6. store outputs //// TC kernel //// use k warps per thread block //// 1. threadblock use read-only cache to read in register tile for A into shared memory //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments //// 3. each warp reads a segment of values 16x32 from B //// 4. do dequantization from register of B into second pair of registers //// 5. store (4) into fragment //// 6. matmul aggregate into fragment C //// 7. aggregate files of C into shared memory block C //// 8. sum (7) //// 9. write outputs to matmul output matrix #if __CUDA_ARCH__ >= 750 using namespace nvcuda; int col_offset = blockIdx.x * 32; const int warp_id = threadIdx.x / 32; const int warp_idx = threadIdx.x % 32; const int half_warp_id = threadIdx.x / 16; const int half_warp_lane = threadIdx.x % 16; const int batch_size_warps = (WARPS - 1) * 2; T quant_map[16]; #pragma unroll 16 for (int i = 0; i < 16; i++) quant_map[i] = nf4_dequantization_lut[i]; //__shared__ T quant_map[16*160]; T local_A[2]; T local_B[64]; unsigned char local_B_4bit[32]; const int a_tile_offset = 16; const int b_tile_offset = (16 * 32 + 16); __shared__ T smem_A[8 * 16 + (16 * (batch_size_warps - 1))]; __shared__ T smem_B[2 * batch_size_warps * 16 * 32 + (2 * 16 * (batch_size_warps - 1))]; __shared__ T smem_C[8 * 32]; wmma::fragment a_frag; wmma::fragment b_frag; wmma::fragment c_frag; wmma::fill_fragment(c_frag, 0.0f); for (int i = threadIdx.x; i < (8 * 32); i += blockDim.x) smem_C[i] = 0.0f; __syncthreads(); int ticktock = 0; int idx = 0 + threadIdx.x; int loaded_values = 0; // prefetch if (idx < K && warp_id < (WARPS - 1)) { if (loaded_values == 0) { local_A[0] = A[idx]; local_A[1] = A[idx + blockDim.x - 32]; #pragma unroll 32 for (int col = 0; col < 32; col++) local_B_4bit[col] = B[(col_offset + col) * ldb + idx]; loaded_values = 1; } else { local_A[0] = local_A[1]; loaded_values--; #pragma unroll 64 for (int col = 0; col < 64; col += 2) { // local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); // local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); // local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0); // local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0); // local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0); // local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0); // local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0); // local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0); local_B[col] = quant_map[160 * (local_B_4bit[col / 2] >> 4) + warp_idx] * T(17.0); local_B[col + 1] = quant_map[160 * (local_B_4bit[col / 2] & 0x0F) + warp_idx] * T(17.0); } } smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0]; #pragma unroll 32 for (int col = 0; col < 32; col++) smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = local_B[col]; } else if (warp_id < (WARPS - 1)) { local_A[0] = T(0.0); smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f; #pragma unroll 32 for (int col = 0; col < 32; col++) local_B[col] = 0.0f; #pragma unroll 32 for (int col = 0; col < 32; col++) smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = 0.0f; } ticktock = ticktock == 0 ? 1 : 0; // if(threadIdx.x == 0) // printf("aa %i %i\n", idx, loaded_values); // for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) for (int base_idx = blockDim.x - 32; base_idx < K; base_idx += blockDim.x - 32) { idx = base_idx + threadIdx.x; // if(threadIdx.x == 0) // printf("%i %i\n", idx, loaded_values); //__syncthreads(); if (idx < K && warp_id < (WARPS - 1)) { if (loaded_values == 0) { local_A[0] = A[idx]; local_A[1] = A[idx + blockDim.x - 32]; #pragma unroll 32 for (int col = 0; col < 32; col++) { local_B_4bit[col] = B[(col_offset + col) * ldb + idx]; local_B_4bit[col + 16] = B[(col_offset + col) * ldb + idx]; } loaded_values = 1; } else { local_A[0] = local_A[1]; loaded_values--; int absidx = (idx + col_offset) / blocksize; half local_absmax = __ldg(&(absmax[absidx])); #pragma unroll 64 for (int col = 0; col < 64; col += 2) { // local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); // local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); // local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); // local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx); // local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); // local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); local_B[col] = quant_map[(local_B_4bit[col / 2] >> 4)] * T(absidx); local_B[col + 1] = quant_map[(local_B_4bit[col / 2] & 0x0F)] * T(absidx); } // printnonzero(local_B, 128, ""); } smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0]; #pragma unroll 32 for (int col = 0; col < 32; col++) smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = local_B[col]; } else if (warp_id < (WARPS - 1)) { local_A[0] = T(0.0); smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f; #pragma unroll 32 for (int col = 0; col < 32; col++) local_B[col] = 0.0f; #pragma unroll 32 for (int col = 0; col < 32; col++) smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] = 0.0f; } ticktock = ticktock == 0 ? 1 : 0; if (warp_id == (WARPS - 1)) for (int k = 0; k < batch_size_warps; k++) { wmma::load_matrix_sync( a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16 ); // 111 mu wmma::load_matrix_sync( b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16 ); // 35 mu wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); } } __syncthreads(); // if(threadIdx.x == 0) //{ // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); // printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); // } if (warp_id != (WARPS - 1)) { return; } // only warp_id == (WARPS-1) from here int warp_lane = threadIdx.x % 32; ticktock = ticktock == 0 ? 1 : 0; for (int k = 0; k < batch_size_warps; k++) { // if(warp_lane == 0) // printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16); // 111 mu wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16); // 35 mu wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); } // 129 mu if (warp_id == (WARPS - 1)) wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); // printnonzero(smem_C, 32, ""); if (col_offset + warp_lane < M) out[col_offset + warp_lane] = smem_C[warp_lane]; #endif } #define num_values_4bit 32 template __global__ void kgemm_4bit_inference_naive( int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out, int lda, int ldb, int ldc, int blocksize ) { // per threadblock: // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] // 4 warps -> 4 loads per iter // 1x32 * 32x4 -> 1x4 outputs per thread block typedef cub::WarpReduce WarpReduce; __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32]; const int warp_idx = threadIdx.x / 32; const int warp_lane = threadIdx.x % 32; const int row_B = (THREADS / 32) * blockIdx.x + warp_idx; const int offset_B = ldb * row_B; const int num_values_8bit = num_values_4bit / 2; float local_C = 0.0f; unsigned char local_B_4bit[num_values_8bit]; T local_B[num_values_4bit / 4]; T local_A[num_values_4bit / 4]; __shared__ T quant_map[16]; T local_absmax = T(0.0f); if (threadIdx.x < 16) quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x])); // for(int i = threadIdx.x; i < 16; i++) // quant_map[i] = T(__ldg(&datatype[i])); __syncthreads(); // A: [1, K] // B: [N, K] for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) { const int inner_idx_halved = inner_idx / 2; // Since blocksize will always be a power-of-2, we avoid more expensive // division by the blocksize and instead use a shift operation. // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. const int absidx = ((2 * offset_B) + inner_idx) >> (31 - __clz(blocksize)); local_absmax = __ldg(&(absmax[absidx])); if (row_B < M) { if ((inner_idx_halved + num_values_8bit) < (K / 2)) { // this is the most important for performance considerations reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; } else { #pragma unroll for (int j = 0; j < (num_values_8bit); j++) if ((inner_idx_halved) + j < (K / 2)) local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; else local_B_4bit[j] = 0b01110111; } } else { #pragma unroll for (int j = 0; j < (num_values_8bit); j++) local_B_4bit[j] = 0b01110111; } for (int i = 0; i < 4; i++) { #pragma unroll for (int k = 0; k < num_values_8bit / 4; k++) { #if BNB_BF16_AVAILABLE local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; #else // bf16 multipliation not supported local_B[k * 2] = T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * (float)local_absmax); local_B[k * 2 + 1] = T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * (float)local_absmax); #endif } if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { // this is also relatively important for performance if (BITS == 16) { reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx / (num_values_4bit / 4) + i]; } else { reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; } } else #pragma unroll for (int k = 0; k < num_values_4bit / 4; k++) if (inner_idx + (i * num_values_4bit / 4) + k < K) local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; else local_A[k] = T(0.0f); // accumulate in float; small performance hit for Ampere, but lower error for outputs #pragma unroll for (int k = 0; k < num_values_4bit / 4; k++) { #if BNB_BF16_AVAILABLE local_C += (float)(local_A[k] * local_B[k]); #else // bf16 multipliation not supported local_C += ((float)local_A[k] * (float)local_B[k]); #endif } } } local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); if (row_B < M && warp_lane == 0) out[row_B] = T(local_C); } template __global__ void kfunc(T* A, T* B, T value, long n) { for (long i = (blockDim.x * blockIdx.x) + threadIdx.x; i < n; i += (blockDim.x * gridDim.x)) { switch (FUNC) { case FILL: A[i] = (T)value; break; case ARANGE: A[i] = (T)i; break; case _MUL: A[i] = A[i] * B[i]; break; } } } //============================================================== // TEMPLATE DEFINITIONS //============================================================== template __global__ void kfunc(float* A, float* B, float value, long n); template __global__ void kfunc(unsigned char* A, unsigned char* B, unsigned char value, long n); template __global__ void kfunc(float* A, float* B, float value, long n); template __global__ void kfunc(float* A, float* B, float value, long n); // these are not used and make no sense, but the compiler needs them // template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, // float * out, int lda, int ldb, int ldc); template __global__ void gemm_device( int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc ); template __global__ void gemm_device( int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc ); template __global__ void gemm_device( int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc ); template __global__ void gemm_device( int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc ); // template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, // float * out, int lda, int ldb, int ldc); template __global__ void gemm_device( int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc ); template __global__ void gemm_device( int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc ); template __global__ void gemm_device( int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc ); // these are not used and make no sense, but the compiler needs them // template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, // float * out, int lda, int ldb, int ldc); template __global__ void gemm_device( int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc ); template __global__ void gemm_device( int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc ); template __global__ void gemm_device( int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc ); template __global__ void gemm_device( int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc ); // template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, // float * out, int lda, int ldb, int ldc); template __global__ void gemm_device( int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc ); template __global__ void gemm_device( int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc ); template __global__ void gemm_device( int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc ); template __global__ void kgemm_4bit_inference( int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize ); template __global__ void kgemm_4bit_inference( int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize ); template __global__ void kgemm_4bit_inference( int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize ); template __global__ void kgemm_4bit_inference( int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize ); template __global__ void kgemm_4bit_inference_naive( int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, half* out, int lda, int ldb, int ldc, int blocksize ); template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>( int M, int N, int K, __nv_bfloat16* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, __nv_bfloat16* out, int lda, int ldb, int ldc, int blocksize ); template __global__ void kgemm_4bit_inference_naive( int M, int N, int K, float* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, float* out, int lda, int ldb, int ldc, int blocksize ); template __global__ void kspmm_coo_very_sparse_naive( int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB ); template __global__ void kspmm_coo_very_sparse_naive( int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB ); template __global__ void kspmm_coo_very_sparse_naive( int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB ); template __global__ void kspmm_coo_very_sparse_naive( int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB ); template __global__ void kspmm_coo_very_sparse_naive( int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB ); template __global__ void kspmm_coo_very_sparse_naive( int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB ); template __global__ void kdequant_mm_int32_fp16<4, 512>( int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, half* __restrict__ const bias, const int numRows, const int numCols, const int n ); template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ template __global__ void kPreconditionOptimizer32bit1State( \ gtype * g, gtype * p, float* state1, float* unorm, const float beta1, const float beta2, const float eps, \ const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n \ ); MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) MAKE_PreconditionOptimizer32bit1State(MOMENTUM, __nv_bfloat16) MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) MAKE_PreconditionOptimizer32bit1State(RMSPROP, __nv_bfloat16) MAKE_PreconditionOptimizer32bit1State(LION, half) MAKE_PreconditionOptimizer32bit1State(LION, float) MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16) #define MAKE_Optimizer32bit1State(oname, gtype) \ template __global__ void kOptimizer32bit1State( \ gtype * g, gtype * p, float* state1, float* unorm, const float max_unorm, const float param_norm, \ const float beta1, const float beta2, const float eps, const float weight_decay, const int step, \ const float lr, const float gnorm_scale, const bool skip_zeros, const int n \ ); MAKE_Optimizer32bit1State(MOMENTUM, half) MAKE_Optimizer32bit1State(MOMENTUM, float) MAKE_Optimizer32bit1State(MOMENTUM, __nv_bfloat16) MAKE_Optimizer32bit1State(RMSPROP, half) MAKE_Optimizer32bit1State(RMSPROP, float) MAKE_Optimizer32bit1State(RMSPROP, __nv_bfloat16) MAKE_Optimizer32bit1State(LION, half) MAKE_Optimizer32bit1State(LION, float) MAKE_Optimizer32bit1State(LION, __nv_bfloat16) MAKE_Optimizer32bit1State(ADAGRAD, half) MAKE_Optimizer32bit1State(ADAGRAD, float) MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16) #define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ template __global__ void kPreconditionOptimizer32bit2State( \ gtype * g, gtype * p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, \ const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, \ const int n \ ); MAKE_PreconditionOptimizer32bit2State(ADAM, float) MAKE_PreconditionOptimizer32bit2State(ADAM, half) MAKE_PreconditionOptimizer32bit2State(ADAM, __nv_bfloat16) MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float) MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half) MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, __nv_bfloat16) template __global__ void kOptimizer32bit2State( float* g, float* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n ); template __global__ void kOptimizer32bit2State( half* g, half* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n ); template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>( __nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n ); template __global__ void kOptimizer32bit2State( float* g, float* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n ); template __global__ void kOptimizer32bit2State( half* g, half* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n ); template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>( __nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n ); #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ template __global__ void kPreconditionOptimizerStatic8bit1State( \ gtype * p, gtype* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, \ const float beta1, const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, \ float* max1, float* new_max1, const float weight_decay, const float gnorm_scale, const int n \ ); MAKE_PreconditionStatic8bit1State(MOMENTUM, half) MAKE_PreconditionStatic8bit1State(MOMENTUM, float) MAKE_PreconditionStatic8bit1State(RMSPROP, half) MAKE_PreconditionStatic8bit1State(RMSPROP, float) MAKE_PreconditionStatic8bit1State(LION, half) MAKE_PreconditionStatic8bit1State(LION, float) MAKE_PreconditionStatic8bit1State(ADAGRAD, half) MAKE_PreconditionStatic8bit1State(ADAGRAD, float) #define MAKE_optimizerStatic8bit1State(oname, gtype) \ template __global__ void kOptimizerStatic8bit1State( \ gtype * p, gtype* const g, unsigned char* state1, const float* unorm, const float max_unorm, \ const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, \ float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, \ const float gnorm_scale, const int n \ ); MAKE_optimizerStatic8bit1State(MOMENTUM, half) MAKE_optimizerStatic8bit1State(MOMENTUM, float) MAKE_optimizerStatic8bit1State(RMSPROP, half) MAKE_optimizerStatic8bit1State(RMSPROP, float) MAKE_optimizerStatic8bit1State(LION, half) MAKE_optimizerStatic8bit1State(LION, float) MAKE_optimizerStatic8bit1State(ADAGRAD, half) MAKE_optimizerStatic8bit1State(ADAGRAD, float) #define MAKE_PreconditionStatic8bit2State(oname, gtype) \ template __global__ void kPreconditionOptimizerStatic8bit2State( \ gtype * p, gtype* __restrict__ const g, unsigned char* __restrict__ const state1, \ unsigned char* __restrict__ const state2, float* unorm, const float beta1, const float beta2, const float eps, \ const int step, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, \ float* max2, float* new_max1, float* new_max2, const float gnorm_scale, const int n \ ); MAKE_PreconditionStatic8bit2State(ADAM, half) MAKE_PreconditionStatic8bit2State(ADAM, float) #define MAKE_optimizerStatic8bit2State(oname, gtype) \ template __global__ void kOptimizerStatic8bit2State( \ gtype * p, gtype* const g, unsigned char* state1, unsigned char* state2, const float* unorm, \ const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, \ const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, \ const int n \ ); MAKE_optimizerStatic8bit2State(ADAM, half) MAKE_optimizerStatic8bit2State(ADAM, float) template __global__ void kPercentileClipping(float* __restrict__ g, float* gnorm_vec, int step, const int n); template __global__ void kPercentileClipping(half* __restrict__ g, float* gnorm_vec, int step, const int n); #define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ template __global__ void kQuantizeBlockwise( \ float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \ const int rand_offset, const int n \ ); MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 1, General8bit) MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, General8bit) MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, General8bit) MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, General8bit) MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, General8bit) MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, FP4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, FP4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, FP4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, FP4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, FP4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4) template __global__ void kDequantizeBlockwise( float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n ); template __global__ void kDequantizeBlockwise( float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n ); template __global__ void kDequantizeBlockwise( float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n ); template __global__ void kDequantizeBlockwise( float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n ); template __global__ void kDequantizeBlockwise( float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n ); template __global__ void kDequantizeBlockwise( float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n ); template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, FP4>( float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, const int n ); template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General8bit>( float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, const int n ); template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>( float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, const int n ); #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ template __global__ void kOptimizerStatic8bit2StateBlockwise( \ gtype * p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, \ const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr, \ float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, \ float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n \ ); MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1) MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1) MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 256, 1) MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 256, 1) MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1) MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1) #define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ template __global__ void kOptimizerStatic8bit1StateBlockwise( \ gtype * p, gtype* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, \ const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, \ float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n \ ); MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, __nv_bfloat16, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, __nv_bfloat16, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1) template __device__ void printnonzero(float* A, int num_values, const char* strval); template __device__ void printnonzero(half* A, int num_values, const char* strval);