Commit 4b025748 authored by Matthew Douglas

Lint fix

parent 1813b058
@@ -320,6 +320,7 @@ name2optimizer_id = {
     "ademamix": ADEMAMIX,
 }
 
+
 @torch.compile
 def _optimizer_precondition_32bit(
     g: torch.Tensor,
@@ -525,29 +526,53 @@ def _(
     if optimizer_name == "lion":
         _optimizer_update_32bit(
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            lr, gnorm_scale, optimizer_id
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            lr,
+            gnorm_scale,
+            optimizer_id,
         )
 
         if max_unorm > 0.0:
             unorm_vec.zero_()
             _optimizer_precondition_32bit(
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                lr, gnorm_scale, optimizer_id
+                g, p, state1, state2, unorm_vec, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, optimizer_id
             )
     else:
         if max_unorm > 0.0:
             unorm_vec.zero_()
             _optimizer_precondition_32bit(
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                lr, gnorm_scale, optimizer_id
+                g, p, state1, state2, unorm_vec, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, optimizer_id
             )
 
         _optimizer_update_32bit(
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            lr, gnorm_scale, optimizer_id
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            lr,
+            gnorm_scale,
+            optimizer_id,
         )
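For reference, the `max_unorm` path these calls share implements update-norm clipping: the preconditioning kernel accumulates the squared update norm into `unorm_vec`, and the update kernel rescales the step whenever that norm exceeds `max_unorm * param_norm` (for Lion the update runs first, apparently so the norm reflects the refreshed momentum state). A minimal sketch of the clipping rule, with illustrative names not taken from this codebase:

```python
import torch

def clipped_update_scale(unorm_vec: torch.Tensor, max_unorm: float, param_norm: float) -> float:
    # unorm_vec accumulates the *squared* update norm, so take the square root first.
    current_unorm = float(unorm_vec) ** 0.5
    if current_unorm > max_unorm * param_norm:
        return (max_unorm * param_norm) / current_unorm
    return 1.0
```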
@@ -4,6 +4,7 @@ import torch
 import triton
 import triton.language as tl
 
 # from triton.language.extra import libdevice
+
 MOMENTUM = 0
@@ -23,6 +24,7 @@ name2optimizer_id = {
     "ademamix": ADEMAMIX,
 }
 
+
 @triton.jit
 def _optimizer_precondition_2state_32bit(
     g_ptr,
@@ -49,32 +51,32 @@ def _optimizer_precondition_2state_32bit(
     block_start_idx = pid * N_PER_TH
     offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
     mask = offsets < n_elements
 
     g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0)
     s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
     s2_vals = tl.load(state2_ptr + offsets, mask=mask, other=0.0)
 
     g_vals = gnorm_scale * g_vals
 
     correction1 = 1.0 / (1.0 - beta1_step)
     correction2 = 1.0 / (1.0 - beta2_step)
 
-    if OPTIMIZER_ID == 3: # ADAM
+    if OPTIMIZER_ID == 3:  # ADAM
         s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals
         s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals
         s1_vals = s1_vals * correction1
         s2_vals = s2_vals * correction2
         update_vals = s1_vals / (tl.sqrt(s2_vals) + eps)
         update_norm = update_vals * update_vals
-    elif OPTIMIZER_ID == 5: # ADEMAMIX
+    elif OPTIMIZER_ID == 5:  # ADEMAMIX
         update_norm = s1_vals
 
     total_norm = tl.sum(tl.where(mask, update_norm, 0.0))
     tl.atomic_add(unorm_ptr, total_norm)
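The ADAM branch above is the standard bias-corrected Adam preconditioner: both moments are updated, rescaled by `1 / (1 - beta**step)` (the `beta1_step`/`beta2_step` arguments are presumably the precomputed powers), and the squared update is reduced into `unorm_ptr`. A scalar sketch of the same computation, assuming those power definitions:

```python
# Scalar walkthrough of the ADAM preconditioning branch (illustrative only).
def adam_squared_update(g, s1, s2, beta1, beta2, eps, step):
    s1 = s1 * beta1 + (1.0 - beta1) * g        # first moment EMA
    s2 = s2 * beta2 + (1.0 - beta2) * g * g    # second moment EMA
    s1 = s1 / (1.0 - beta1 ** step)            # bias corrections
    s2 = s2 / (1.0 - beta2 ** step)
    update = s1 / (s2 ** 0.5 + eps)
    return update * update                     # summed into unorm_ptr
```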
@@ -89,7 +91,7 @@ def _optimizer_precondition_1state_32bit(
     beta2: tl.constexpr,
     eps: tl.constexpr,
     weight_decay,
-    step,
+    step,
     beta1_step,
     beta2_step,
     lr,
@@ -104,12 +106,12 @@ def _optimizer_precondition_1state_32bit(
     block_start_idx = pid * N_PER_TH
     offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
     mask = offsets < n_elements
 
     g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0)
     s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
 
     g_vals = gnorm_scale * g_vals
 
     if OPTIMIZER_ID == 0:  # MOMENTUM
         if step == 1:
             s1_vals = g_vals
@@ -130,9 +132,9 @@ def _optimizer_precondition_1state_32bit(
         s1_vals = s1_vals + g_vals * g_vals
         update_vals = g_vals / (tl.sqrt(s1_vals) + eps)
         update_norm = update_vals * update_vals
 
     total_norm = tl.sum(tl.where(mask, update_norm, 0.0))
     tl.atomic_add(unorm_ptr, total_norm)
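For the single-state optimizers the preconditioner reduces the squared update the same way; for example, the MOMENTUM branch (seeded with the raw gradient at `step == 1`) and the adagrad-style branch at the end of this hunk amount to:

```python
# Scalar sketches of two single-state preconditioning branches (illustrative).
def momentum_squared_update(g, s1, beta1, step):
    s1 = g if step == 1 else s1 * beta1 + g    # classic accumulated momentum
    return s1 * s1

def adagrad_squared_update(g, s1, eps):
    s1 = s1 + g * g                            # running sum of squared gradients
    update = g / (s1 ** 0.5 + eps)
    return update * update
```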
@@ -151,7 +153,7 @@ def _optimizer_update_2state_32bit_triton_kernel(
     alpha,
     eps: tl.constexpr,
     weight_decay: tl.constexpr,
-    step,
+    step,
     beta1_step,
     beta2_step,
     lr,
@@ -167,23 +169,23 @@ def _optimizer_update_2state_32bit_triton_kernel(
     block_start_idx = pid * N_PER_TH
     offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
     mask = offsets < n_elements
 
     g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
     p_vals = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
 
     s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
     s2_vals = tl.load(state2_ptr + offsets, mask=mask, other=0.0)
 
     if OPTIMIZER_ID == 5:  # ADEMAMIX
         s3_vals = tl.load(state1_ptr + n_elements + offsets, mask=mask, other=0.0)
 
     g_vals = gnorm_scale * g_vals
 
     update_scale = 1.0
     if max_unorm > 0.0:
         current_unorm = tl.sqrt(tl.load(unorm_ptr))
         if current_unorm > max_unorm * param_norm:
             update_scale = (max_unorm * param_norm) / current_unorm
 
     if OPTIMIZER_ID == 3:  # ADAM
         s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals
         s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals
@@ -197,8 +199,8 @@ def _optimizer_update_2state_32bit_triton_kernel(
         update_val = update_scale * step_size * (s1_vals / (tl.sqrt(s2_vals) + eps * correction2))
         p_vals = p_vals + update_val
 
-    elif OPTIMIZER_ID == 5: # ADEMAMIX
+    elif OPTIMIZER_ID == 5:  # ADEMAMIX
         s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals  # m1
         s3_vals = s3_vals * beta3 + (1.0 - beta3) * g_vals  # m2
         s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals  # nu
@@ -208,15 +210,15 @@ def _optimizer_update_2state_32bit_triton_kernel(
         if weight_decay > 0.0:
             p_vals = p_vals * (1.0 - lr * weight_decay)
 
         mixed_momentum = (s1_vals / correction1) + (alpha * s3_vals)
         adaptive_term = (tl.sqrt(s2_vals) / correction2) + eps
         p_vals = p_vals - lr * (mixed_momentum / adaptive_term)
 
     tl.store(p_ptr + offsets, p_vals, mask=mask)
     tl.store(state1_ptr + offsets, s1_vals, mask=mask)
     tl.store(state2_ptr + offsets, s2_vals, mask=mask)
 
     if OPTIMIZER_ID == 5:  # ADEMAMIX
         tl.store(state1_ptr + n_elements + offsets, s3_vals, mask=mask)
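The ADEMAMIX branch combines a fast EMA (`s1_vals`, the `m1` comment), a slow EMA (`s3_vals`, `m2`, stored after the first `n_elements` of `state1`), and an Adam-style second moment (`s2_vals`, `nu`). Reading the update off the hunk above, and assuming the elided `correction1`/`correction2` are `1 - beta1**step` and `sqrt(1 - beta2**step)` (the usual AdEMAMix bias corrections), a scalar sketch is:

```python
# Scalar sketch of the ADEMAMIX update; the correction terms are an assumption,
# since their definitions fall outside this diff.
def ademamix_step(p, g, m1, m2, nu, beta1, beta2, beta3, alpha, eps, lr, weight_decay, step):
    m1 = m1 * beta1 + (1.0 - beta1) * g        # fast EMA
    m2 = m2 * beta3 + (1.0 - beta3) * g        # slow EMA
    nu = nu * beta2 + (1.0 - beta2) * g * g    # second moment
    if weight_decay > 0.0:
        p = p * (1.0 - lr * weight_decay)      # decoupled weight decay
    c1 = 1.0 - beta1 ** step                   # assumed bias-correction terms
    c2 = (1.0 - beta2 ** step) ** 0.5
    p = p - lr * ((m1 / c1 + alpha * m2) / (nu ** 0.5 / c2 + eps))
    return p, m1, m2, nu
```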
@@ -224,7 +226,7 @@ def _optimizer_update_2state_32bit_triton_kernel(
 @triton.jit
 def _optimizer_update_1state_32bit_triton_kernel(
     g_ptr,
-    p_ptr,
+    p_ptr,
     state1_ptr,
     state2_ptr,
     unorm_ptr,
@@ -252,7 +254,7 @@ def _optimizer_update_1state_32bit_triton_kernel(
     block_start_idx = pid * N_PER_TH
     offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
     mask = offsets < n_elements
 
     g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
     p_vals = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
     s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
@@ -260,19 +262,19 @@ def _optimizer_update_1state_32bit_triton_kernel(
     g_vals = gnorm_scale * g_vals
     if weight_decay > 0.0:
         g_vals = g_vals + p_vals * weight_decay
 
     update_scale = 1.0
     if max_unorm > 0.0:
         current_unorm = tl.sqrt(tl.load(unorm_ptr))
         if current_unorm > max_unorm * param_norm + eps:
             update_scale = (max_unorm * param_norm + eps) / current_unorm
 
     if OPTIMIZER_ID == 0:  # MOMENTUM
         if step == 1:
             s1_vals = g_vals
         else:
             s1_vals = s1_vals * beta1 + g_vals
 
         update_val = update_scale * (-lr * s1_vals)
         p_vals = p_vals + update_val
@@ -280,21 +282,21 @@ def _optimizer_update_1state_32bit_triton_kernel(
         momentum_update = s1_vals * beta1 + (1.0 - beta1) * g_vals
         update_val = update_scale * lr * tl.where(momentum_update > 0, 1.0, tl.where(momentum_update < 0, -1.0, 0.0))
         p_vals = p_vals - update_val
 
         s1_vals = s1_vals * beta2 + (1.0 - beta2) * g_vals
 
     elif OPTIMIZER_ID == 1:  # RMSPROP
         s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals * g_vals
         update_val = update_scale * lr * g_vals / (tl.sqrt(s1_vals) + eps)
         p_vals = p_vals - update_val
 
     elif OPTIMIZER_ID == 2:  # ADAGRAD
         s1_vals = s1_vals + g_vals * g_vals
         update_val = lr * g_vals / (tl.sqrt(s1_vals) + eps)
         p_vals = p_vals - update_val
 
     tl.store(p_ptr + offsets, p_vals, mask=mask)
     tl.store(state1_ptr + offsets, s1_vals, mask=mask)
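The unlabeled branch at the top of this hunk matches the Lion rule (its `elif` label falls outside the hunk): the parameter moves by the sign of an interpolation between the momentum and the current gradient, with the nested `tl.where` acting as a sign function, and the momentum EMA is then refreshed with `beta2`. A scalar sketch:

```python
# Scalar sketch of the Lion branch (illustrative names).
def lion_step(p, g, m, beta1, beta2, lr, update_scale=1.0):
    blend = m * beta1 + (1.0 - beta1) * g              # interpolated momentum
    sign = 1.0 if blend > 0 else (-1.0 if blend < 0 else 0.0)
    p = p - update_scale * lr * sign                   # signed step
    m = m * beta2 + (1.0 - beta2) * g                  # momentum refresh uses beta2
    return p, m
```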
@@ -367,34 +369,104 @@ def optimizer_update_32bit_impl(
     if optimizer_name == "lion":
         fn_update[grid](
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            beta1_step, beta2_step, lr, gnorm_scale, skip_zeros,
-            p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            beta1_step,
+            beta2_step,
+            lr,
+            gnorm_scale,
+            skip_zeros,
+            p.numel(),
+            optimizer_id,
+            BLOCK_SIZE,
+            N_PER_TH,
+            num_warps=2,
         )
 
         if max_unorm > 0.0:
             unorm_vec.zero_()
             fn_preprocess[grid](
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                beta1_step, beta2_step, lr, gnorm_scale,
-                p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+                g,
+                p,
+                state1,
+                state2,
+                unorm_vec,
+                beta1,
+                beta2,
+                eps,
+                weight_decay,
+                step,
+                beta1_step,
+                beta2_step,
+                lr,
+                gnorm_scale,
+                p.numel(),
+                optimizer_id,
+                BLOCK_SIZE,
+                N_PER_TH,
+                num_warps=2,
             )
     else:
         if max_unorm > 0.0:
             unorm_vec.zero_()
             fn_preprocess[grid](
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                beta1_step, beta2_step, lr, gnorm_scale,
-                p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+                g,
+                p,
+                state1,
+                state2,
+                unorm_vec,
+                beta1,
+                beta2,
+                eps,
+                weight_decay,
+                step,
+                beta1_step,
+                beta2_step,
+                lr,
+                gnorm_scale,
+                p.numel(),
+                optimizer_id,
+                BLOCK_SIZE,
+                N_PER_TH,
+                num_warps=2,
             )
 
         fn_update[grid](
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            beta1_step, beta2_step, lr, gnorm_scale, skip_zeros,
-            p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            beta1_step,
+            beta2_step,
+            lr,
+            gnorm_scale,
+            skip_zeros,
+            p.numel(),
+            optimizer_id,
+            BLOCK_SIZE,
+            N_PER_TH,
+            num_warps=2,
         )
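On launch geometry: every kernel computes `offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)` with `block_start_idx = pid * N_PER_TH`, so each program instance covers `BLOCK_SIZE * N_PER_TH` contiguous elements. The `grid` used above is elided from this diff, but it is presumably the matching ceiling division, along these lines:

```python
import triton

# Hypothetical grid helper matching the per-program tile size used by the
# kernels; the actual grid definition is not part of this diff.
def make_grid(n_elements: int, block_size: int, n_per_th: int) -> tuple[int]:
    return (triton.cdiv(n_elements, block_size * n_per_th),)
```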
@@ -3,7 +3,7 @@ from typing import Optional
 import torch
 
-from . import triton_kernels, kernels_optim
+from . import kernels_optim, triton_kernels
 
 # currently unused code, kept for reference
 # Should be the same for quant/dequant