Commit f1d9d80a authored by Titus von Koeller's avatar Titus von Koeller
Browse files

lib methods can now safely be assigned, no more cryptic errors on missing lib

parent fe6cd17e
...@@ -19,102 +19,101 @@ from .cextension import lib ...@@ -19,102 +19,101 @@ from .cextension import lib
name2qmap = {} name2qmap = {}
if lib and lib.compiled_with_cuda: """C FUNCTIONS FOR OPTIMIZERS"""
"""C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = {
str2optimizer32bit = { "adam": (
"adam": ( lib.cadam32bit_grad_fp32,
lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16,
lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16,
lib.cadam32bit_grad_bf16, ),
), "momentum": (
"momentum": ( lib.cmomentum32bit_grad_32,
lib.cmomentum32bit_grad_32, lib.cmomentum32bit_grad_16,
lib.cmomentum32bit_grad_16, ),
), "rmsprop": (
"rmsprop": ( lib.crmsprop32bit_grad_32,
lib.crmsprop32bit_grad_32, lib.crmsprop32bit_grad_16,
lib.crmsprop32bit_grad_16, ),
), "lion": (
"lion": ( lib.clion32bit_grad_fp32,
lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16,
lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16,
lib.clion32bit_grad_bf16, ),
), "adagrad": (
"adagrad": ( lib.cadagrad32bit_grad_32,
lib.cadagrad32bit_grad_32, lib.cadagrad32bit_grad_16,
lib.cadagrad32bit_grad_16, ),
), "lamb": (
"lamb": ( lib.cadam32bit_grad_fp32,
lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16,
lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16,
lib.cadam32bit_grad_bf16, ),
), "ademamix": (
"ademamix": ( lib.cademamix32bit_grad_fp32,
lib.cademamix32bit_grad_fp32, lib.cademamix32bit_grad_fp16,
lib.cademamix32bit_grad_fp16, lib.cademamix32bit_grad_bf16,
lib.cademamix32bit_grad_bf16, ),
), }
}
str2optimizer8bit = {
str2optimizer8bit = { "adam": (
"adam": ( lib.cadam_static_8bit_grad_32,
lib.cadam_static_8bit_grad_32, lib.cadam_static_8bit_grad_16,
lib.cadam_static_8bit_grad_16, ),
), "momentum": (
"momentum": ( lib.cmomentum_static_8bit_grad_32,
lib.cmomentum_static_8bit_grad_32, lib.cmomentum_static_8bit_grad_16,
lib.cmomentum_static_8bit_grad_16, ),
), "rmsprop": (
"rmsprop": ( lib.crmsprop_static_8bit_grad_32,
lib.crmsprop_static_8bit_grad_32, lib.crmsprop_static_8bit_grad_16,
lib.crmsprop_static_8bit_grad_16, ),
), "lion": (
"lion": ( lib.clion_static_8bit_grad_32,
lib.clion_static_8bit_grad_32, lib.clion_static_8bit_grad_16,
lib.clion_static_8bit_grad_16, ),
), "lamb": (
"lamb": ( lib.cadam_static_8bit_grad_32,
lib.cadam_static_8bit_grad_32, lib.cadam_static_8bit_grad_16,
lib.cadam_static_8bit_grad_16, ),
), "lars": (
"lars": ( lib.cmomentum_static_8bit_grad_32,
lib.cmomentum_static_8bit_grad_32, lib.cmomentum_static_8bit_grad_16,
lib.cmomentum_static_8bit_grad_16, ),
), }
}
str2optimizer8bit_blockwise = {
str2optimizer8bit_blockwise = { "adam": (
"adam": ( lib.cadam_8bit_blockwise_grad_fp32,
lib.cadam_8bit_blockwise_grad_fp32, lib.cadam_8bit_blockwise_grad_fp16,
lib.cadam_8bit_blockwise_grad_fp16, lib.cadam_8bit_blockwise_grad_bf16,
lib.cadam_8bit_blockwise_grad_bf16, ),
), "momentum": (
"momentum": ( lib.cmomentum_8bit_blockwise_grad_fp32,
lib.cmomentum_8bit_blockwise_grad_fp32, lib.cmomentum_8bit_blockwise_grad_fp16,
lib.cmomentum_8bit_blockwise_grad_fp16, lib.cmomentum_8bit_blockwise_grad_bf16,
lib.cmomentum_8bit_blockwise_grad_bf16, ),
), "rmsprop": (
"rmsprop": ( lib.crmsprop_8bit_blockwise_grad_fp32,
lib.crmsprop_8bit_blockwise_grad_fp32, lib.crmsprop_8bit_blockwise_grad_fp16,
lib.crmsprop_8bit_blockwise_grad_fp16, lib.crmsprop_8bit_blockwise_grad_bf16,
lib.crmsprop_8bit_blockwise_grad_bf16, ),
), "lion": (
"lion": ( lib.clion_8bit_blockwise_grad_fp32,
lib.clion_8bit_blockwise_grad_fp32, lib.clion_8bit_blockwise_grad_fp16,
lib.clion_8bit_blockwise_grad_fp16, lib.clion_8bit_blockwise_grad_bf16,
lib.clion_8bit_blockwise_grad_bf16, ),
), "adagrad": (
"adagrad": ( lib.cadagrad_8bit_blockwise_grad_fp32,
lib.cadagrad_8bit_blockwise_grad_fp32, lib.cadagrad_8bit_blockwise_grad_fp16,
lib.cadagrad_8bit_blockwise_grad_fp16, lib.cadagrad_8bit_blockwise_grad_bf16,
lib.cadagrad_8bit_blockwise_grad_bf16, ),
), "ademamix": (
"ademamix": ( lib.cademamix_8bit_blockwise_grad_fp32,
lib.cademamix_8bit_blockwise_grad_fp32, lib.cademamix_8bit_blockwise_grad_fp16,
lib.cademamix_8bit_blockwise_grad_fp16, lib.cademamix_8bit_blockwise_grad_bf16,
lib.cademamix_8bit_blockwise_grad_bf16, ),
), }
}
class GlobalPageManager: class GlobalPageManager:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment