Commit 4b025748 authored by Matthew Douglas

Lint fix

parent 1813b058
@@ -320,6 +320,7 @@ name2optimizer_id = {
     "ademamix": ADEMAMIX,
 }
 
+
 @torch.compile
 def _optimizer_precondition_32bit(
     g: torch.Tensor,
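Context for the hunk above: `_optimizer_precondition_32bit` is decorated with `@torch.compile`, so PyTorch traces its tensor math and fuses it into a compiled kernel. A minimal sketch of that decorator pattern, using a hypothetical helper rather than the repository's kernel:

import torch


# Hypothetical helper, not the repository's kernel: @torch.compile traces
# the elementwise math and compiles it into a fused kernel on first call
# (requires PyTorch 2.0+).
@torch.compile
def _scaled_update(p: torch.Tensor, g: torch.Tensor, lr: float) -> torch.Tensor:
    return p - lr * g


p = torch.randn(1024)
g = torch.randn(1024)
p = _scaled_update(p, g, 0.01)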
@@ -525,29 +526,53 @@ def _(
     if optimizer_name == "lion":
         _optimizer_update_32bit(
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            lr, gnorm_scale, optimizer_id
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            lr,
+            gnorm_scale,
+            optimizer_id,
         )
         if max_unorm > 0.0:
             unorm_vec.zero_()
             _optimizer_precondition_32bit(
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                lr, gnorm_scale, optimizer_id
+                g, p, state1, state2, unorm_vec, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, optimizer_id
             )
     else:
         if max_unorm > 0.0:
             unorm_vec.zero_()
             _optimizer_precondition_32bit(
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                lr, gnorm_scale, optimizer_id
+                g, p, state1, state2, unorm_vec, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, optimizer_id
             )
         _optimizer_update_32bit(
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            lr, gnorm_scale, optimizer_id
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            lr,
+            gnorm_scale,
+            optimizer_id,
         )
@@ -4,6 +4,7 @@ import torch
 import triton
 import triton.language as tl
 
 # from triton.language.extra import libdevice
+
 MOMENTUM = 0
@@ -23,6 +24,7 @@ name2optimizer_id = {
     "ademamix": ADEMAMIX,
 }
 
+
 @triton.jit
 def _optimizer_precondition_2state_32bit(
     g_ptr,
@@ -367,34 +369,104 @@ def optimizer_update_32bit_impl(
     if optimizer_name == "lion":
         fn_update[grid](
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            beta1_step, beta2_step, lr, gnorm_scale, skip_zeros,
-            p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            beta1_step,
+            beta2_step,
+            lr,
+            gnorm_scale,
+            skip_zeros,
+            p.numel(),
+            optimizer_id,
+            BLOCK_SIZE,
+            N_PER_TH,
+            num_warps=2,
         )
         if max_unorm > 0.0:
             unorm_vec.zero_()
             fn_preprocess[grid](
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                beta1_step, beta2_step, lr, gnorm_scale,
-                p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+                g,
+                p,
+                state1,
+                state2,
+                unorm_vec,
+                beta1,
+                beta2,
+                eps,
+                weight_decay,
+                step,
+                beta1_step,
+                beta2_step,
+                lr,
+                gnorm_scale,
+                p.numel(),
+                optimizer_id,
+                BLOCK_SIZE,
+                N_PER_TH,
+                num_warps=2,
             )
     else:
         if max_unorm > 0.0:
             unorm_vec.zero_()
             fn_preprocess[grid](
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                beta1_step, beta2_step, lr, gnorm_scale,
-                p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+                g,
+                p,
+                state1,
+                state2,
+                unorm_vec,
+                beta1,
+                beta2,
+                eps,
+                weight_decay,
+                step,
+                beta1_step,
+                beta2_step,
+                lr,
+                gnorm_scale,
+                p.numel(),
+                optimizer_id,
+                BLOCK_SIZE,
+                N_PER_TH,
+                num_warps=2,
             )
         fn_update[grid](
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            beta1_step, beta2_step, lr, gnorm_scale, skip_zeros,
-            p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            beta1_step,
+            beta2_step,
+            lr,
+            gnorm_scale,
+            skip_zeros,
+            p.numel(),
+            optimizer_id,
+            BLOCK_SIZE,
+            N_PER_TH,
+            num_warps=2,
        )
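The `fn_update[grid](...)` and `fn_preprocess[grid](...)` calls above use Triton's kernel-launch syntax: a `@triton.jit` kernel is indexed with a grid, then called with its runtime arguments, its `tl.constexpr` parameters (here `BLOCK_SIZE` and `N_PER_TH`), and launch options such as `num_warps`. A minimal, self-contained sketch of the same pattern with a hypothetical kernel (assumes a CUDA device; not the repository's code):

import torch
import triton
import triton.language as tl


@triton.jit
def _scale_kernel(x_ptr, out_ptr, scale, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the input.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * scale, mask=mask)


x = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
BLOCK_SIZE = 1024
# One program per BLOCK_SIZE-wide chunk; num_warps is a launch option,
# mirroring the num_warps=2 in the calls above.
grid = (triton.cdiv(x.numel(), BLOCK_SIZE),)
_scale_kernel[grid](x, out, 2.0, x.numel(), BLOCK_SIZE, num_warps=2)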
@@ -3,7 +3,7 @@ from typing import Optional
 import torch
 
-from . import triton_kernels, kernels_optim
+from . import kernels_optim, triton_kernels
 
 # currently codes unused, kept for reference
 # Should be the same for quant/dequant
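A note on the reformatting pattern across these hunks (an inference from the diff, not stated in the commit message): calls whose argument lists end with a trailing comma are expanded to one argument per line, while the `_optimizer_precondition_32bit` calls, which lack a trailing comma, are collapsed onto a single line. This matches the "magic trailing comma" behavior of formatters such as Black and Ruff. A small sketch with hypothetical names:

def update(g, p, s1, s2, lr):
    # Hypothetical stand-in for a call with many arguments.
    return p - lr * (g + s1 + s2)


# A trailing comma after the last argument keeps one argument per line:
out = update(
    1.0,
    2.0,
    3.0,
    4.0,
    0.01,
)

# Without a trailing comma, the formatter joins the arguments onto one
# line whenever they fit within the line-length limit:
out = update(1.0, 2.0, 3.0, 4.0, 0.01)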