Commit 4b025748 authored by Matthew Douglas

Lint fix

parent 1813b058
@@ -320,6 +320,7 @@ name2optimizer_id = {
     "ademamix": ADEMAMIX,
 }
 
+
 @torch.compile
 def _optimizer_precondition_32bit(
     g: torch.Tensor,
@@ -525,29 +526,53 @@ def _(
 
     if optimizer_name == "lion":
         _optimizer_update_32bit(
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            lr, gnorm_scale, optimizer_id
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            lr,
+            gnorm_scale,
+            optimizer_id,
         )
 
         if max_unorm > 0.0:
             unorm_vec.zero_()
             _optimizer_precondition_32bit(
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                lr, gnorm_scale, optimizer_id
+                g, p, state1, state2, unorm_vec, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, optimizer_id
             )
     else:
         if max_unorm > 0.0:
             unorm_vec.zero_()
             _optimizer_precondition_32bit(
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                lr, gnorm_scale, optimizer_id
+                g, p, state1, state2, unorm_vec, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, optimizer_id
             )
 
         _optimizer_update_32bit(
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            lr, gnorm_scale, optimizer_id
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            lr,
+            gnorm_scale,
+            optimizer_id,
         )
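Note on the ordering above: in both branches the preconditioning pass accumulates the squared update norm into `unorm_vec`; Lion updates first and preconditions afterwards, while the other optimizers precondition first so the update can read the fresh norm. A minimal PyTorch sketch, not part of this commit, of the clipping that `unorm_vec`, `max_unorm`, and `param_norm` feed into (the 1-state kernel additionally adds `eps` to the threshold):

```python
import torch

def clip_scale(unorm_vec: torch.Tensor, max_unorm: float, param_norm: float) -> float:
    """Scale factor applied when the update norm exceeds max_unorm * param_norm."""
    current_unorm = unorm_vec.sqrt().item()  # unorm_vec holds the squared update norm
    if max_unorm > 0.0 and current_unorm > max_unorm * param_norm:
        return (max_unorm * param_norm) / current_unorm
    return 1.0
```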
@@ -4,6 +4,7 @@ import torch
 import triton
 import triton.language as tl
 
 # from triton.language.extra import libdevice
 
+
 MOMENTUM = 0
@@ -23,6 +24,7 @@ name2optimizer_id = {
     "ademamix": ADEMAMIX,
 }
 
+
 @triton.jit
 def _optimizer_precondition_2state_32bit(
     g_ptr,
@@ -49,32 +51,32 @@ def _optimizer_precondition_2state_32bit(
     block_start_idx = pid * N_PER_TH
     offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
     mask = offsets < n_elements
 
     g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0)
     s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
     s2_vals = tl.load(state2_ptr + offsets, mask=mask, other=0.0)
 
     g_vals = gnorm_scale * g_vals
 
     correction1 = 1.0 / (1.0 - beta1_step)
     correction2 = 1.0 / (1.0 - beta2_step)
 
     if OPTIMIZER_ID == 3:  # ADAM
         s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals
         s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals
         s1_vals = s1_vals * correction1
         s2_vals = s2_vals * correction2
         update_vals = s1_vals / (tl.sqrt(s2_vals) + eps)
         update_norm = update_vals * update_vals
     elif OPTIMIZER_ID == 5:  # ADEMAMIX
         update_norm = s1_vals
 
     total_norm = tl.sum(tl.where(mask, update_norm, 0.0))
     tl.atomic_add(unorm_ptr, total_norm)
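For reference, the ADAM branch of this kernel is the standard bias-corrected estimate of the would-be update, reduced to a squared norm. A plain-PyTorch sketch, not part of this commit, assuming `beta1_step` and `beta2_step` are the precomputed powers `beta1**step` and `beta2**step`:

```python
import torch

def adam_update_sq_norm(g, m, v, beta1, beta2, eps, step):
    # Bias-corrected moments, matching the kernel's correction1/correction2.
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1**step)
    v_hat = v / (1 - beta2**step)
    update = m_hat / (v_hat.sqrt() + eps)
    return (update * update).sum()  # accumulated into unorm_ptr via atomic_add
```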
@@ -89,7 +91,7 @@ def _optimizer_precondition_1state_32bit(
     beta2: tl.constexpr,
     eps: tl.constexpr,
     weight_decay,
     step,
     beta1_step,
     beta2_step,
     lr,
@@ -104,12 +106,12 @@ def _optimizer_precondition_1state_32bit(
     block_start_idx = pid * N_PER_TH
     offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
     mask = offsets < n_elements
 
     g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0)
     s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
 
     g_vals = gnorm_scale * g_vals
 
     if OPTIMIZER_ID == 0:  # MOMENTUM
         if step == 1:
             s1_vals = g_vals
@@ -130,9 +132,9 @@ def _optimizer_precondition_1state_32bit(
         s1_vals = s1_vals + g_vals * g_vals
         update_vals = g_vals / (tl.sqrt(s1_vals) + eps)
         update_norm = update_vals * update_vals
 
     total_norm = tl.sum(tl.where(mask, update_norm, 0.0))
     tl.atomic_add(unorm_ptr, total_norm)
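The ADAGRAD tail of this kernel accumulates squared gradients with no decay before measuring the update norm; sketched in plain PyTorch for comparison (hypothetical helper, not in this commit):

```python
import torch

def adagrad_update_sq_norm(g, s1, eps):
    s1 = s1 + g * g                      # monotone accumulator, no EMA decay
    update = g / (s1.sqrt() + eps)
    return s1, (update * update).sum()   # squared norm for the clipping pass
```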
@@ -151,7 +153,7 @@ def _optimizer_update_2state_32bit_triton_kernel(
     alpha,
     eps: tl.constexpr,
     weight_decay: tl.constexpr,
     step,
     beta1_step,
     beta2_step,
     lr,
@@ -167,23 +169,23 @@ def _optimizer_update_2state_32bit_triton_kernel(
     block_start_idx = pid * N_PER_TH
     offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
     mask = offsets < n_elements
 
     g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
     p_vals = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
     s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
     s2_vals = tl.load(state2_ptr + offsets, mask=mask, other=0.0)
 
     if OPTIMIZER_ID == 5:  # ADEMAMIX
         s3_vals = tl.load(state1_ptr + n_elements + offsets, mask=mask, other=0.0)
 
     g_vals = gnorm_scale * g_vals
 
     update_scale = 1.0
     if max_unorm > 0.0:
         current_unorm = tl.sqrt(tl.load(unorm_ptr))
         if current_unorm > max_unorm * param_norm:
             update_scale = (max_unorm * param_norm) / current_unorm
 
     if OPTIMIZER_ID == 3:  # ADAM
         s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals
         s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals
@@ -197,8 +199,8 @@ def _optimizer_update_2state_32bit_triton_kernel(
         update_val = update_scale * step_size * (s1_vals / (tl.sqrt(s2_vals) + eps * correction2))
         p_vals = p_vals + update_val
 
     elif OPTIMIZER_ID == 5:  # ADEMAMIX
         s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals  # m1
         s3_vals = s3_vals * beta3 + (1.0 - beta3) * g_vals  # m2
         s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals  # nu
@@ -208,15 +210,15 @@ def _optimizer_update_2state_32bit_triton_kernel(
         if weight_decay > 0.0:
             p_vals = p_vals * (1.0 - lr * weight_decay)
 
         mixed_momentum = (s1_vals / correction1) + (alpha * s3_vals)
         adaptive_term = (tl.sqrt(s2_vals) / correction2) + eps
         p_vals = p_vals - lr * (mixed_momentum / adaptive_term)
 
     tl.store(p_ptr + offsets, p_vals, mask=mask)
     tl.store(state1_ptr + offsets, s1_vals, mask=mask)
     tl.store(state2_ptr + offsets, s2_vals, mask=mask)
 
     if OPTIMIZER_ID == 5:  # ADEMAMIX
         tl.store(state1_ptr + n_elements + offsets, s3_vals, mask=mask)
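The ADEMAMIX branch keeps three states: a fast EMA `m1` (beta1), a slow EMA `m2` (beta3, stored after the first `n_elements` of `state1`), and an Adam-style second moment `nu` (beta2). A plain-PyTorch sketch, not part of this commit, assuming `correction1 = 1 - beta1**step` and `correction2 = sqrt(1 - beta2**step)` (their definitions fall outside this hunk):

```python
import torch

def ademamix_step(p, g, m1, m2, nu, lr, beta1, beta2, beta3, alpha, eps, step, weight_decay=0.0):
    m1 = beta1 * m1 + (1 - beta1) * g        # fast EMA
    m2 = beta3 * m2 + (1 - beta3) * g        # slow EMA, mixed in via alpha
    nu = beta2 * nu + (1 - beta2) * g * g    # second moment
    c1 = 1 - beta1**step                     # assumed correction1
    c2 = (1 - beta2**step) ** 0.5            # assumed correction2
    if weight_decay > 0.0:
        p = p * (1 - lr * weight_decay)      # decoupled weight decay
    p = p - lr * ((m1 / c1 + alpha * m2) / (nu.sqrt() / c2 + eps))
    return p, m1, m2, nu
```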
@@ -224,7 +226,7 @@ def _optimizer_update_2state_32bit_triton_kernel(
 @triton.jit
 def _optimizer_update_1state_32bit_triton_kernel(
     g_ptr,
     p_ptr,
     state1_ptr,
     state2_ptr,
     unorm_ptr,
@@ -252,7 +254,7 @@ def _optimizer_update_1state_32bit_triton_kernel(
     block_start_idx = pid * N_PER_TH
     offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
     mask = offsets < n_elements
 
     g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
     p_vals = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
     s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
@@ -260,19 +262,19 @@ def _optimizer_update_1state_32bit_triton_kernel(
     g_vals = gnorm_scale * g_vals
     if weight_decay > 0.0:
         g_vals = g_vals + p_vals * weight_decay
 
     update_scale = 1.0
     if max_unorm > 0.0:
         current_unorm = tl.sqrt(tl.load(unorm_ptr))
         if current_unorm > max_unorm * param_norm + eps:
             update_scale = (max_unorm * param_norm + eps) / current_unorm
 
     if OPTIMIZER_ID == 0:  # MOMENTUM
         if step == 1:
             s1_vals = g_vals
         else:
             s1_vals = s1_vals * beta1 + g_vals
 
         update_val = update_scale * (-lr * s1_vals)
         p_vals = p_vals + update_val
@@ -280,21 +282,21 @@ def _optimizer_update_1state_32bit_triton_kernel(
         momentum_update = s1_vals * beta1 + (1.0 - beta1) * g_vals
         update_val = update_scale * lr * tl.where(momentum_update > 0, 1.0, tl.where(momentum_update < 0, -1.0, 0.0))
         p_vals = p_vals - update_val
 
         s1_vals = s1_vals * beta2 + (1.0 - beta2) * g_vals
     elif OPTIMIZER_ID == 1:  # RMSPROP
         s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals * g_vals
         update_val = update_scale * lr * g_vals / (tl.sqrt(s1_vals) + eps)
         p_vals = p_vals - update_val
     elif OPTIMIZER_ID == 2:  # ADAGRAD
         s1_vals = s1_vals + g_vals * g_vals
         update_val = lr * g_vals / (tl.sqrt(s1_vals) + eps)
         p_vals = p_vals - update_val
 
     tl.store(p_ptr + offsets, p_vals, mask=mask)
     tl.store(state1_ptr + offsets, s1_vals, mask=mask)
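The unlabeled branch before RMSPROP (its `elif` header falls outside the hunk) appears to be Lion: the parameter moves along the sign of an interpolated momentum, and the stored momentum is then updated with beta2. A plain-PyTorch sketch with a hypothetical helper, not part of this commit:

```python
import torch

def lion_step(p, g, m, lr, beta1, beta2, weight_decay=0.0, update_scale=1.0):
    if weight_decay > 0.0:
        g = g + weight_decay * p             # decay folded into the gradient, as in the kernel
    c = beta1 * m + (1 - beta1) * g          # interpolation used only for its sign
    p = p - update_scale * lr * torch.sign(c)
    m = beta2 * m + (1 - beta2) * g          # stored momentum uses beta2
    return p, m
```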
@@ -367,34 +369,104 @@ def optimizer_update_32bit_impl(
 
     if optimizer_name == "lion":
         fn_update[grid](
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            beta1_step, beta2_step, lr, gnorm_scale, skip_zeros,
-            p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            beta1_step,
+            beta2_step,
+            lr,
+            gnorm_scale,
+            skip_zeros,
+            p.numel(),
+            optimizer_id,
+            BLOCK_SIZE,
+            N_PER_TH,
+            num_warps=2,
         )
 
         if max_unorm > 0.0:
             unorm_vec.zero_()
             fn_preprocess[grid](
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                beta1_step, beta2_step, lr, gnorm_scale,
-                p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+                g,
+                p,
+                state1,
+                state2,
+                unorm_vec,
+                beta1,
+                beta2,
+                eps,
+                weight_decay,
+                step,
+                beta1_step,
+                beta2_step,
+                lr,
+                gnorm_scale,
+                p.numel(),
+                optimizer_id,
+                BLOCK_SIZE,
+                N_PER_TH,
+                num_warps=2,
             )
     else:
         if max_unorm > 0.0:
             unorm_vec.zero_()
             fn_preprocess[grid](
-                g, p, state1, state2, unorm_vec,
-                beta1, beta2, eps, weight_decay, step,
-                beta1_step, beta2_step, lr, gnorm_scale,
-                p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+                g,
+                p,
+                state1,
+                state2,
+                unorm_vec,
+                beta1,
+                beta2,
+                eps,
+                weight_decay,
+                step,
+                beta1_step,
+                beta2_step,
+                lr,
+                gnorm_scale,
+                p.numel(),
+                optimizer_id,
+                BLOCK_SIZE,
+                N_PER_TH,
+                num_warps=2,
             )
 
         fn_update[grid](
-            g, p, state1, state2, unorm_vec, max_unorm, param_norm,
-            beta1, beta2, beta3, alpha, eps, weight_decay, step,
-            beta1_step, beta2_step, lr, gnorm_scale, skip_zeros,
-            p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
+            g,
+            p,
+            state1,
+            state2,
+            unorm_vec,
+            max_unorm,
+            param_norm,
+            beta1,
+            beta2,
+            beta3,
+            alpha,
+            eps,
+            weight_decay,
+            step,
+            beta1_step,
+            beta2_step,
+            lr,
+            gnorm_scale,
+            skip_zeros,
+            p.numel(),
+            optimizer_id,
+            BLOCK_SIZE,
+            N_PER_TH,
+            num_warps=2,
        )
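The `grid`, `BLOCK_SIZE`, and `N_PER_TH` values are defined outside this hunk. A hypothetical sketch of the launch geometry the kernels imply, where each program instance covers `BLOCK_SIZE * N_PER_TH` elements (names reused, values assumed):

```python
import triton

BLOCK_SIZE, N_PER_TH = 1024, 4  # assumed values; the real constants sit outside this diff

def make_grid(n_elements: int):
    # One program instance per BLOCK_SIZE * N_PER_TH chunk, matching
    # offsets = pid * N_PER_TH * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH).
    return (triton.cdiv(n_elements, BLOCK_SIZE * N_PER_TH),)
```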
@@ -3,7 +3,7 @@ from typing import Optional
 
 import torch
 
-from . import triton_kernels, kernels_optim
+from . import kernels_optim, triton_kernels
 
 # currently codes unused, kept for reference
 # Should be the same for quant/dequant
...