Commit 4b025748 authored by Matthew Douglas

Lint fix

parent 1813b058
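
All changes below are formatting only; behavior is unchanged. The diff applies three kinds of lint fixes: blank-line spacing before top-level definitions, reflowing of optimizer call sites (long calls split to one argument per line with a trailing comma, short calls collapsed onto a single line), and alphabetical sorting of names within an import statement.

As a minimal sketch of the reflow rule, assuming a black/ruff-format-style formatter (the specific tool is not named in this commit) and a stand-in function update that is not part of the codebase:

def update(g, p, state1, state2, lr, eps):
    # Placeholder body; only the call formatting matters for this sketch.
    return g

# Call fits within the line-length limit: kept on one line, like the
# _optimizer_precondition_32bit calls in this diff.
out = update(0.1, 0.2, 0.3, 0.4, 0.5, 0.6)

# Call exceeds the limit (or carries a "magic" trailing comma): split to one
# argument per line, like the _optimizer_update_32bit calls in this diff.
out = update(
    0.1,
    0.2,
    0.3,
    0.4,
    0.5,
    0.6,
)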
@@ -320,6 +320,7 @@ name2optimizer_id = {
     "ademamix": ADEMAMIX,
 }
 
+
 @torch.compile
 def _optimizer_precondition_32bit(
     g: torch.Tensor,
@@ -525,29 +526,53 @@ def _(
     if optimizer_name == "lion":
         _optimizer_update_32bit(
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            lr, gnorm_scale, optimizer_id
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            lr,
+            gnorm_scale,
+            optimizer_id,
         )
 
         if max_unorm > 0.0:
             unorm_vec.zero_()
             _optimizer_precondition_32bit(
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                lr, gnorm_scale, optimizer_id
+                g, p, state1, state2, unorm_vec, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, optimizer_id
             )
     else:
         if max_unorm > 0.0:
             unorm_vec.zero_()
             _optimizer_precondition_32bit(
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                lr, gnorm_scale, optimizer_id
+                g, p, state1, state2, unorm_vec, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, optimizer_id
             )
 
         _optimizer_update_32bit(
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            lr, gnorm_scale, optimizer_id
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            lr,
+            gnorm_scale,
+            optimizer_id,
         )
@@ -4,6 +4,7 @@ import torch
 import triton
 import triton.language as tl
 
 # from triton.language.extra import libdevice
 
+
 MOMENTUM = 0
@@ -23,6 +24,7 @@ name2optimizer_id = {
     "ademamix": ADEMAMIX,
 }
 
+
 @triton.jit
 def _optimizer_precondition_2state_32bit(
     g_ptr,
@@ -367,34 +369,104 @@ def optimizer_update_32bit_impl(
     if optimizer_name == "lion":
         fn_update[grid](
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            beta1_step, beta2_step, lr, gnorm_scale, skip_zeros,
-            p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            beta1_step,
+            beta2_step,
+            lr,
+            gnorm_scale,
+            skip_zeros,
+            p.numel(),
+            optimizer_id,
+            BLOCK_SIZE,
+            N_PER_TH,
+            num_warps=2,
         )
 
         if max_unorm > 0.0:
             unorm_vec.zero_()
             fn_preprocess[grid](
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                beta1_step, beta2_step, lr, gnorm_scale,
-                p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+                g,
+                p,
+                state1,
+                state2,
+                unorm_vec,
+                beta1,
+                beta2,
+                eps,
+                weight_decay,
+                step,
+                beta1_step,
+                beta2_step,
+                lr,
+                gnorm_scale,
+                p.numel(),
+                optimizer_id,
+                BLOCK_SIZE,
+                N_PER_TH,
+                num_warps=2,
             )
     else:
         if max_unorm > 0.0:
             unorm_vec.zero_()
             fn_preprocess[grid](
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                beta1_step, beta2_step, lr, gnorm_scale,
-                p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+                g,
+                p,
+                state1,
+                state2,
+                unorm_vec,
+                beta1,
+                beta2,
+                eps,
+                weight_decay,
+                step,
+                beta1_step,
+                beta2_step,
+                lr,
+                gnorm_scale,
+                p.numel(),
+                optimizer_id,
+                BLOCK_SIZE,
+                N_PER_TH,
+                num_warps=2,
             )
 
         fn_update[grid](
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            beta1_step, beta2_step, lr, gnorm_scale, skip_zeros,
-            p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            beta1_step,
+            beta2_step,
+            lr,
+            gnorm_scale,
+            skip_zeros,
+            p.numel(),
+            optimizer_id,
+            BLOCK_SIZE,
+            N_PER_TH,
+            num_warps=2,
         )
@@ -3,7 +3,7 @@ from typing import Optional
 import torch
 
-from . import triton_kernels, kernels_optim
+from . import kernels_optim, triton_kernels
 
 # currently codes unused, kept for reference
 # Should be the same for quant/dequant
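
The last hunk reorders the names in a single from-import alphabetically. A minimal stdlib illustration of the same rule, assuming isort/ruff-style sorting (the exact linter configuration is an assumption, not stated in this commit):

# Before (would be flagged): from os.path import dirname, basename
# After (names sorted alphabetically within the statement):
from os.path import basename, dirname

print(basename("/tmp/example.txt"), dirname("/tmp/example.txt"))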