Unverified commit aa57bd89, authored by Matthew Douglas and committed by GitHub

Change 8bit optimizer blocksize 2048->256; additional bf16 support (#1365)

* Change 8bit optimizer blocksize 2048->256; additional bf16 support
* Update tolerances for 8bit optimizer tests
parent d9645465
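The core of the change is visible in the Python state-allocation code below: the 8-bit blockwise optimizers now quantize their state in blocks of 256 values instead of 2048, so the float32 absmax buffers hold one scale per 256 elements (8x more entries, each covering a tighter range). A minimal sketch of that allocation math, mirroring the ceil division used in the diff (the helper name and the example tensor size are illustrative):

import torch

def num_blocks(n: int, blocksize: int = 256) -> int:
    # ceil(n / blocksize): one float32 absmax entry per quantization block
    return (n // blocksize) + bool(n % blocksize)

n = 4096 * 4096                      # e.g. one 4096x4096 weight matrix
print(num_blocks(n, 2048))           # 8192 blocks with the old blocksize
print(num_blocks(n, 256))            # 65536 blocks after this change
absmax1 = torch.zeros((num_blocks(n),), dtype=torch.float32)  # as in Optimizer1State/2State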
@@ -52,6 +52,7 @@ if lib and lib.compiled_with_cuda:
         "lamb": (
             lib.cadam32bit_grad_fp32,
             lib.cadam32bit_grad_fp16,
+            lib.cadam32bit_grad_bf16,
         ),
         "ademamix": (
             lib.cademamix32bit_grad_fp32,
@@ -96,10 +97,12 @@ if lib and lib.compiled_with_cuda:
         "momentum": (
             lib.cmomentum_8bit_blockwise_grad_fp32,
             lib.cmomentum_8bit_blockwise_grad_fp16,
+            lib.cmomentum_8bit_blockwise_grad_bf16,
         ),
         "rmsprop": (
             lib.crmsprop_8bit_blockwise_grad_fp32,
             lib.crmsprop_8bit_blockwise_grad_fp16,
+            lib.crmsprop_8bit_blockwise_grad_bf16,
         ),
         "lion": (
             lib.clion_8bit_blockwise_grad_fp32,
@@ -109,6 +112,7 @@ if lib and lib.compiled_with_cuda:
         "adagrad": (
             lib.cadagrad_8bit_blockwise_grad_fp32,
             lib.cadagrad_8bit_blockwise_grad_fp16,
+            lib.cadagrad_8bit_blockwise_grad_bf16,
         ),
         "ademamix": (
             lib.cademamix_8bit_blockwise_grad_fp32,
@@ -398,7 +402,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     data.append(0)
     data.sort()
-    return Tensor(data)
+    return torch.tensor(data)
 def create_quantile_map(A, total_bits=8):
...
@@ -166,7 +166,7 @@ class AdEMAMix(Optimizer2State):
         self.name2qmap["udynamic"] = state["qmap2"] = self.name2qmap["udynamic"].to(p.device)
         n = p.numel()
-        blocks = (n // 2048) + bool(n % 2048)
+        blocks = (n // 256) + bool(n % 256)
         state["absmax1"] = torch.zeros((2, blocks), dtype=torch.float32, device=p.device)
         state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
...
@@ -477,8 +477,8 @@ class Optimizer2State(Optimizer8bit):
         if config["block_wise"]:
             n = p.numel()
-            blocks = n // 2048
-            blocks += 1 if n % 2048 > 0 else 0
+            blocks = n // 256
+            blocks += 1 if n % 256 > 0 else 0
             state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
             state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
@@ -699,8 +699,8 @@ class Optimizer1State(Optimizer8bit):
         if config["block_wise"]:
             n = p.numel()
-            blocks = n // 2048
-            blocks += 1 if n % 2048 > 0 else 0
+            blocks = n // 256
+            blocks += 1 if n % 256 > 0 else 0
             state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
         else:
...
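The CUDA changes that follow instantiate the missing __nv_bfloat16 kernels, so momentum (SGD), RMSprop, and Adagrad can take bf16 parameters and gradients in both their 32-bit and 8-bit blockwise forms. A hedged usage sketch, assuming bnb.optim.SGD8bit is the library's 8-bit blockwise momentum optimizer (the model and loss here are illustrative):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096, dtype=torch.bfloat16, device="cuda")
opt = bnb.optim.SGD8bit(model.parameters(), lr=1e-3, momentum=0.9)

x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
loss = model(x).float().pow(2).mean()
loss.backward()                      # bf16 gradients
opt.step()                           # dispatches to a *_8bit_blockwise_grad_bf16 entry point
opt.zero_grad()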
@@ -3829,13 +3829,16 @@ template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8
 MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
 MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
+MAKE_PreconditionOptimizer32bit1State(MOMENTUM, __nv_bfloat16)
 MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
 MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
+MAKE_PreconditionOptimizer32bit1State(RMSPROP, __nv_bfloat16)
 MAKE_PreconditionOptimizer32bit1State(LION, half)
 MAKE_PreconditionOptimizer32bit1State(LION, float)
 MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
+MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16)
 #define MAKE_Optimizer32bit1State(oname, gtype) \
 template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
@@ -3843,13 +3846,16 @@ template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p,
 MAKE_Optimizer32bit1State(MOMENTUM, half)
 MAKE_Optimizer32bit1State(MOMENTUM, float)
+MAKE_Optimizer32bit1State(MOMENTUM, __nv_bfloat16)
 MAKE_Optimizer32bit1State(RMSPROP, half)
 MAKE_Optimizer32bit1State(RMSPROP, float)
+MAKE_Optimizer32bit1State(RMSPROP, __nv_bfloat16)
 MAKE_Optimizer32bit1State(LION, half)
 MAKE_Optimizer32bit1State(LION, float)
 MAKE_Optimizer32bit1State(LION, __nv_bfloat16)
 MAKE_Optimizer32bit1State(ADAGRAD, half)
 MAKE_Optimizer32bit1State(ADAGRAD, float)
+MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16)
 #define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
 template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
@@ -3950,6 +3956,8 @@ MAKE_optimizerStatic8bit2State(ADAM, float)
 template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
 template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);
+// template __global__ void kPercentileClipping<float, 128, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
+// template __global__ void kPercentileClipping<half, 128, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);
 #define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
 template __global__ void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \
@@ -4041,13 +4049,12 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block
     float weight_decay, \
     const float gnorm_scale, const bool skip_zeros, const int n); \
-MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
-MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8)
-MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 2048, 8)
-MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 2048, 8)
-MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 2048, 8)
-MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 2048, 8)
+MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1)
+MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1)
+MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 256, 1)
+MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 256, 1)
+MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1)
+MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1)
 #define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
 template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
@@ -4059,15 +4066,18 @@ template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block
     float weight_decay, \
     const float gnorm_scale, const bool skip_zeros, const int n); \
-MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
-MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, __nv_bfloat16, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, __nv_bfloat16, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1)
+MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1)
 template __device__ void printnonzero<float>(float *A, int num_values, const char*strval);
 template __device__ void printnonzero<half>(half *A, int num_values, const char*strval);
@@ -191,10 +191,10 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
   }
 }
-#define BLOCKSIZE_2STATE 2048
-#define NUM_2STATE 8
-#define BLOCKSIZE_1STATE 2048
-#define NUM_1STATE 8
+#define BLOCKSIZE_2STATE 256
+#define NUM_2STATE 1
+#define BLOCKSIZE_1STATE 256
+#define NUM_1STATE 1
 template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
     T* p,
@@ -818,13 +818,16 @@ MAKE_optimizer32bit(ADAM, float)
 MAKE_optimizer32bit(ADAM, __nv_bfloat16)
 MAKE_optimizer32bit(MOMENTUM, half)
 MAKE_optimizer32bit(MOMENTUM, float)
+MAKE_optimizer32bit(MOMENTUM, __nv_bfloat16)
 MAKE_optimizer32bit(RMSPROP, half)
 MAKE_optimizer32bit(RMSPROP, float)
+MAKE_optimizer32bit(RMSPROP, __nv_bfloat16)
 MAKE_optimizer32bit(LION, half)
 MAKE_optimizer32bit(LION, float)
 MAKE_optimizer32bit(LION, __nv_bfloat16)
 MAKE_optimizer32bit(ADAGRAD, half)
 MAKE_optimizer32bit(ADAGRAD, float)
+MAKE_optimizer32bit(ADAGRAD, __nv_bfloat16)
 MAKE_optimizer32bit(ADEMAMIX, half)
 MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16)
 MAKE_optimizer32bit(ADEMAMIX, float)
@@ -861,13 +864,16 @@ MAKE_optimizerStatic8bitBlockwise(float, ADAM);
 MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM);
 MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
 MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
+MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, MOMENTUM);
 MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
+MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(half, LION);
 MAKE_optimizerStatic8bitBlockwise(float, LION);
 MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION);
 MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
 MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
+MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAGRAD);
 MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX);
 MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADEMAMIX);
 MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);
...
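The BLOCKSIZE_*/NUM_* defines above also change the launch geometry of the blockwise 8-bit kernels: a quantization block used to be 2048 values processed by 256 threads (8 values per thread) and is now 256 values processed by 256 threads (1 value per thread), so the grid gets 8x more CUDA blocks for the same tensor. A small arithmetic sketch, assuming the usual <<<num_blocks, BLOCKSIZE / NUM>>> launch pattern used for these kernels:

def launch_geometry(n: int, blocksize: int, num_per_thread: int):
    cuda_blocks = (n + blocksize - 1) // blocksize    # one CUDA block per quantization block
    threads_per_block = blocksize // num_per_thread   # each thread handles num_per_thread values
    return cuda_blocks, threads_per_block

n = 4096 * 4096
print(launch_geometry(n, 2048, 8))   # old: (8192, 256)
print(launch_geometry(n, 256, 1))    # new: (65536, 256)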
@@ -103,19 +103,22 @@ void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
 { optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
 MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
+MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(adam, ADAM, float, fp32)
 MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16)
+MAKE_BLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32)
 MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16)
+MAKE_BLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32)
 MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16)
+MAKE_BLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
-MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(lion, LION, half, fp16)
-MAKE_BLOCKWISE8(lion, LION, float, fp32)
 MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
-MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
+MAKE_BLOCKWISE8(lion, LION, float, fp32)
 MAKE_BLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
+MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(ademamix, ADEMAMIX, float, fp32)
@@ -283,13 +286,16 @@ extern "C"
 MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
 MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
+MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
 MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16)
 MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32)
+MAKE_CBLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
 MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16)
 MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
+MAKE_CBLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
 MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
 MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
-MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
+MAKE_CBLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
 MAKE_CBLOCKWISE8(lion, LION, half, fp16)
 MAKE_CBLOCKWISE8(lion, LION, float, fp32)
 MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
...
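On the Python side, these exported C symbols are what the per-dtype tuples at the top of the diff point to; the gradient dtype selects which entry gets called. A hypothetical sketch of that selection (the helper and the index mapping are mine and only mirror the fp32/fp16/bf16 ordering visible in the first hunks; the library's actual dispatch code is not part of this diff):

import torch

_DTYPE_INDEX = {torch.float32: 0, torch.float16: 1, torch.bfloat16: 2}  # tuple order shown above

def select_blockwise_fn(fns, grad_dtype):
    # e.g. fns = str2optimizer8bit_blockwise["momentum"]; a lookup failure here is the
    # unsupported-dtype case this change removes for bf16 gradients
    return fns[_DTYPE_INDEX[grad_dtype]]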
@@ -74,10 +74,18 @@ str2optimizers["ademamix_scheduled"] = (
     lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
     lambda pxx: bnb.optim.AdEMAMix(pxx, t_alpha=k, t_beta3=k),
 )
+str2optimizers["paged_ademamix_scheduled"] = (
+    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
+    lambda pxx: bnb.optim.PagedAdEMAMix(pxx, t_alpha=k, t_beta3=k),
+)
 str2optimizers["ademamix8bit_blockwise_scheduled"] = (
     lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
     lambda pxx: bnb.optim.AdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
 )
+str2optimizers["paged_ademamix8bit_blockwise_scheduled"] = (
+    lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
+    lambda pxx: bnb.optim.PagedAdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
+)
 str2optimizers["lion"] = (Lion, bnb.optim.Lion)
 str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
@@ -143,7 +151,7 @@ str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"
 str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
 str2statenames["ademamix"] = str2statenames["ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
-str2statenames["paged_ademamix"] = [("m1_m2", "state1"), ("nu", "state2")]
+str2statenames["paged_ademamix"] = str2statenames["paged_ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
 str2statenames["ademamix8bit_blockwise"] = str2statenames["ademamix8bit_blockwise_scheduled"] = [
     ("m1_m2", "state1", "qmap1", "absmax1"),
     ("nu", "state2", "qmap2", "absmax2"),
@@ -164,6 +172,7 @@ optimizer_names_32bit = [
     "ademamix",
     "ademamix_scheduled",
     "paged_ademamix",
+    "paged_ademamix_scheduled",
 ]
@@ -309,18 +318,15 @@ optimizer_names_8bit = [
 def test_optimizer8bit(dim1, dim2, gtype, optim_name):
     torch.set_printoptions(precision=6)
-    if gtype == torch.bfloat16 and optim_name not in [
-        "adam8bit_blockwise",
-        "lion8bit_blockwise",
-        "ademamix8bit_blockwise",
-    ]:
+    if gtype == torch.bfloat16 and "blockwise" not in optim_name:
         pytest.skip()
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
     p2 = p1.clone()
     p1 = p1.float()
-    blocksize = 2048
+    blocksize = 256
     torch_optimizer = str2optimizers[optim_name][0]([p1])
     bnb_optimizer = str2optimizers[optim_name][1]([p2])
@@ -347,8 +353,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
         torch_optimizer.step()
         # since Lion can have pretty noisy updates where things lie at the boundary
-        # and AdEMAMix can diverge as well, allow up to 0.05% errors.
-        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=int(p1.numel() * 5e-4))
+        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)
         dequant_states = []
         for name1, name2, qmap, max_val in str2statenames[optim_name]:
@@ -392,11 +397,11 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
         err = torch.abs(p1 - p2)
         relerr = err / (torch.abs(p1) + 1e-9)
         if g.dtype == torch.bfloat16:
-            assert err.mean() < 0.00015
-            assert relerr.mean() < 0.0020  # 0.0016
+            assert err.mean() <= 0.00017
+            assert relerr.mean() <= 0.0016
         else:
-            assert err.mean() < 0.00016  # 0.00012
-            assert relerr.mean() < 0.0016  # 0.0012
+            assert err.mean() < 0.00006
+            assert relerr.mean() < 0.0006
         errors.append(err.mean().item())
         relerrors.append(relerr.mean().item())
@@ -454,9 +459,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
             num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
             assert num_not_close.sum().item() < 20
-        # since Lion can have pretty noisy updates where things lie at the boundary
-        # and AdEMAMix can also be noisy, allow up to 0.05%.
-        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=int(p1.numel() * 5e-04))
+        # Lion can have pretty noisy updates where things lie at the boundary
+        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)
         # the parameters diverge quickly. Here we keep them close
         # together so we can test against the Adam error
@@ -560,7 +565,11 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
 optimizer_names_benchmark = [
     "adam8bit_blockwise",
     "paged_adam8bit_blockwise",
-    "paged_adamw8bit_blockwise",
+    "ademamix8bit_blockwise",
+    "paged_ademamix8bit_blockwise",
+    "ademamix8bit_blockwise_scheduled",
+    "paged_ademamix8bit_blockwise_scheduled",
+    "lion8bit_blockwise",
     "paged_lion8bit_blockwise",
     "paged_ademamix8bit_blockwise",
 ]
@@ -568,7 +577,7 @@ optimizer_names_benchmark = [
 @pytest.mark.parametrize("dim1", [4096], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [4096], ids=id_formatter("dim2"))
-@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
+@pytest.mark.parametrize("gtype", [torch.float32, torch.bfloat16, torch.float16], ids=describe_dtype)
 @pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt"))
 @pytest.mark.benchmark
 def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
@@ -580,8 +589,9 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
     g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
     p1.grad = g
-    for i in range(k):
-        if i == k // 5:
+    total_steps = 500
+    for i in range(total_steps):
+        if i == total_steps // 5:
             # 100 iterations for burn-in
             torch.cuda.synchronize()
             t0 = time.time()
@@ -591,8 +601,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
     torch.cuda.synchronize()
     s = time.time() - t0
     print("")
-    params = (k - k // 5) * dim1 * dim2
-    print(optim_name, gtype, s / params)
+    params = (total_steps - total_steps // 5) * dim1 * dim2
+    print(optim_name, gtype, s, params, s / params)
     # assert s < 3.9
...
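For reference, the benchmark above reports time per parameter update over the timed portion of the loop: the first 20% of the 500 steps are burn-in, and the printed figure is elapsed seconds divided by (timed steps x dim1 x dim2). A small sketch of that bookkeeping with an illustrative elapsed time:

total_steps, dim1, dim2 = 500, 4096, 4096
timed_steps = total_steps - total_steps // 5   # 400 steps measured after burn-in
params = timed_steps * dim1 * dim2             # parameter updates in the timed window
elapsed = 1.25                                 # example wall-clock seconds
print(elapsed, params, elapsed / params)       # same quantities the test prints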