Commit 1b8772a8 authored by Tim Dettmers

Added PagedLion and bf16 Lion.

parent 2bce175d
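For orientation, here is a minimal usage sketch of what this commit enables from the user side. It assumes a CUDA build of bitsandbytes that includes these changes; the linear layer and input tensor are placeholders, not part of the commit.

import torch
import bitsandbytes as bnb

# Placeholder model; any module with CUDA parameters works.
model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)

# New in this commit: Lion with paged optimizer state (PagedLion, PagedLion8bit,
# PagedLion32bit) and bf16 support in the 32-bit and blockwise 8-bit Lion kernels.
optim = bnb.optim.PagedLion(model.parameters(), lr=1e-4, betas=(0.9, 0.99))

loss = model(torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)).sum()
loss.backward()
optim.step()
optim.zero_grad()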
@@ -37,10 +37,7 @@ if COMPILED_WITH_CUDA:
             lib.crmsprop32bit_grad_32,
             lib.crmsprop32bit_grad_16,
         )
-        str2optimizer32bit["lion"] = (
-            lib.clion32bit_grad_32,
-            lib.clion32bit_grad_16,
-        )
+        str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16)
         str2optimizer32bit["adagrad"] = (
             lib.cadagrad32bit_grad_32,
             lib.cadagrad32bit_grad_16,
@@ -89,6 +86,7 @@ if COMPILED_WITH_CUDA:
         str2optimizer8bit_blockwise["lion"] = (
             lib.clion_8bit_blockwise_grad_fp32,
             lib.clion_8bit_blockwise_grad_fp16,
+            lib.clion_8bit_blockwise_grad_bf16,
         )
         str2optimizer8bit_blockwise["adagrad"] = (
             lib.cadagrad_8bit_blockwise_grad_fp32,
...
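The tuples registered above hold one kernel entry point per gradient dtype (fp32, fp16, and now bf16 for Lion). The following is a simplified sketch of how such a tuple can be dispatched by dtype; it is illustrative only, not the library's actual dispatch code, and pick_lion_kernel is a made-up helper name.

import torch

def pick_lion_kernel(kernels, grad):
    # kernels mirrors str2optimizer32bit["lion"]: (fp32 fn, fp16 fn, bf16 fn)
    if grad.dtype == torch.float32:
        return kernels[0]
    elif grad.dtype == torch.float16:
        return kernels[1]
    elif grad.dtype == torch.bfloat16:
        return kernels[2]  # the entry added by this commit
    raise ValueError(f"unsupported gradient dtype: {grad.dtype}")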
@@ -12,5 +12,5 @@ from .lamb import LAMB, LAMB8bit, LAMB32bit
 from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
 from .optimizer import GlobalOptimManager
 from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
-from .lion import Lion, Lion8bit, Lion32bit
+from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit
 from .sgd import SGD, SGD8bit, SGD32bit
@@ -4,84 +4,27 @@
 # LICENSE file in the root directory of this source tree.
 from bitsandbytes.optim.optimizer import Optimizer1State

 class Lion(Optimizer1State):
-    def __init__(
-        self,
-        params,
-        lr=1e-4,
-        betas=(0.9, 0.99),
-        weight_decay=0,
-        optim_bits=32,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "lion",
-            params,
-            lr,
-            betas,
-            0.,
-            weight_decay,
-            optim_bits,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)

 class Lion8bit(Optimizer1State):
-    def __init__(
-        self,
-        params,
-        lr=1e-4,
-        betas=(0.9, 0.99),
-        weight_decay=0,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "lion",
-            params,
-            lr,
-            betas,
-            0.,
-            weight_decay,
-            8,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)

 class Lion32bit(Optimizer1State):
-    def __init__(
-        self,
-        params,
-        lr=1e-4,
-        betas=(0.9, 0.99),
-        weight_decay=0,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "lion",
-            params,
-            lr,
-            betas,
-            0.,
-            weight_decay,
-            32,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
+
+class PagedLion(Optimizer1State):
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedLion8bit(Optimizer1State):
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedLion32bit(Optimizer1State):
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
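Note that the new Paged* classes differ from their non-paged counterparts only in forcing is_paged=True in the call to Optimizer1State, while the existing Lion classes gain the same flag as an optional argument. In other words, the two constructions below end up configured identically (a sketch, assuming a build that includes this commit):

import torch
import bitsandbytes as bnb

params = [torch.nn.Parameter(torch.randn(4096, 4096, device="cuda"))]

opt_a = bnb.optim.PagedLion(params, lr=1e-4)
opt_b = bnb.optim.Lion(params, lr=1e-4, is_paged=True)  # same configuration as opt_a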
@@ -3666,6 +3666,7 @@ MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
 MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
 MAKE_PreconditionOptimizer32bit1State(LION, half)
 MAKE_PreconditionOptimizer32bit1State(LION, float)
+MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
@@ -3679,6 +3680,7 @@ MAKE_Optimizer32bit1State(RMSPROP, half)
 MAKE_Optimizer32bit1State(RMSPROP, float)
 MAKE_Optimizer32bit1State(LION, half)
 MAKE_Optimizer32bit1State(LION, float)
+MAKE_Optimizer32bit1State(LION, __nv_bfloat16)
 MAKE_Optimizer32bit1State(ADAGRAD, half)
 MAKE_Optimizer32bit1State(ADAGRAD, float)
@@ -3852,5 +3854,6 @@ MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
@@ -802,6 +802,7 @@ MAKE_optimizer32bit(RMSPROP, half)
 MAKE_optimizer32bit(RMSPROP, float)
 MAKE_optimizer32bit(LION, half)
 MAKE_optimizer32bit(LION, float)
+MAKE_optimizer32bit(LION, __nv_bfloat16)
 MAKE_optimizer32bit(ADAGRAD, half)
 MAKE_optimizer32bit(ADAGRAD, float)
@@ -837,6 +838,7 @@ MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(half, LION);
 MAKE_optimizerStatic8bitBlockwise(float, LION);
+MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION);
 MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
 MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
...
@@ -51,8 +51,9 @@ MAKE_FUNC32(adam, ADAM, half, fp16)
 MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16)
 MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
 MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
-MAKE_FUNC32(lion, LION, float, 32)
-MAKE_FUNC32(lion, LION, half, 16)
+MAKE_FUNC32(lion, LION, float, fp32)
+MAKE_FUNC32(lion, LION, half, fp16)
+MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16)
 MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
 MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
@@ -95,6 +96,7 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
 MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(lion, LION, half, fp16)
 MAKE_BLOCKWISE8(lion, LION, float, fp32)
+MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16)

 void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
@@ -201,8 +203,9 @@ extern "C"
 MAKE_CFUNC32(momentum, half, 16)
 MAKE_CFUNC32(rmsprop, float, 32)
 MAKE_CFUNC32(rmsprop, half, 16)
-MAKE_CFUNC32(lion, float, 32)
-MAKE_CFUNC32(lion, half, 16)
+MAKE_CFUNC32(lion, float, fp32)
+MAKE_CFUNC32(lion, half, fp16)
+MAKE_CFUNC32(lion, __nv_bfloat16, bf16)
 MAKE_CFUNC32(adagrad, float, 32)
 MAKE_CFUNC32(adagrad, half, 16)
@@ -245,6 +248,7 @@ extern "C"
 MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
 MAKE_CBLOCKWISE8(lion, LION, half, fp16)
 MAKE_CBLOCKWISE8(lion, LION, float, fp32)
+MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)

 void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
 void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
...
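The MAKE_FUNC32/MAKE_CFUNC32 macros above generate the exported C symbols (for example clion32bit_grad_bf16, which the first hunk now references as lib.clion32bit_grad_bf16). As a rough sketch of that linkage, the symbol can be resolved directly from the compiled shared library; the library filename below is an assumption and depends on your CUDA version and build.

import ctypes

# Hypothetical example: the actual .so name varies by CUDA version/build.
lib = ctypes.CDLL("libbitsandbytes_cuda118.so")
clion32bit_grad_bf16 = lib.clion32bit_grad_bf16  # exported after this commit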
@@ -19,11 +19,11 @@ import bitsandbytes.functional as F

 k = 20

 def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
-    idx = torch.isclose(a, b, rtol, atol)
+    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
     error_count = (idx == 0).sum().item()
     if error_count > max_error_count:
         print(f"Too many values not close: assert {error_count} < {max_error_count}")
-        torch.testing.assert_close(a, b, rtol, atol)
+        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)

 def get_temp_dir():
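Side note on the assert_most_approx_close fix above: torch.isclose accepts rtol and atol positionally, but torch.testing.assert_close takes them as keyword-only arguments, so the old positional call would raise a TypeError once the error budget was exceeded. A small standalone check:

import torch

a = torch.tensor([1.0, 1.001])
b = torch.tensor([1.0, 1.000])

torch.testing.assert_close(a, b, rtol=1e-2, atol=1e-2)  # passes
# torch.testing.assert_close(a, b, 1e-2, 1e-2)          # TypeError: rtol/atol are keyword-only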
@@ -35,13 +35,8 @@ def get_temp_dir():
 def rm_path(path):
     shutil.rmtree(path)

-str2bf16support = {}
-str2bf16support['adam8bit_blockwise'] = True
-
 str2optimizers = {}
 str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
-# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
-# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
 str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
 str2optimizers["momentum_pytorch"] = (
     None,
@@ -51,8 +46,8 @@ str2optimizers["momentum_pytorch"] = (
 str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
 str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
 str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
-# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
 str2optimizers["lion"] = (Lion, bnb.optim.Lion)
+str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
 str2optimizers["momentum"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
@@ -76,6 +71,7 @@ str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.
 str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
 str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
 str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
+str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
 str2optimizers["momentum8bit_blockwise"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
@@ -90,6 +86,7 @@ str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["lion"] = [("exp_avg", "state1")]
+str2statenames["paged_lion"] = [("exp_avg", "state1")]
 str2statenames["momentum"] = [("momentum_buffer", "state1")]
 str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["rmsprop"] = [("square_avg", "state1")]
@@ -104,15 +101,17 @@ str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1
 str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
 str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
 str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
+str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]

 dim1 = [1024]
 dim2 = [32, 1024, 4097, 1]
-gtype = [torch.float32, torch.float16]
-optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion']
+gtype = [torch.float32, torch.float16, torch.bfloat16]
+optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]

 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
 def test_optimizer32bit(dim1, dim2, gtype, optim_name):
+    if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip()
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
@@ -254,7 +253,7 @@ names = [
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
 def test_optimizer8bit(dim1, dim2, gtype, optim_name):
-    if gtype == torch.bfloat16 and optim_name not in str2bf16support: return
+    if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
@@ -485,7 +484,7 @@ gtype = [torch.float32, torch.float16]
 # optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
 # optimizer_names = ['lamb_apex', 'lamb8bit']
 # optimizer_names = ['lars_apex', 'lars8bit']
-optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise']
+optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise']
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
     "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
...