Commit 275671be authored by YangKai0616, committed via GitHub

[XPU] Implemented 32bit optimizers in triton (#1710)



* Implemented 32bit optimizers in triton

* Modify Comments

* Optimizing pure torch implementation

* Restore the order of parameters and modify the position of pure pytorch implementation

* Restore files permissions

---------
Co-authored-by: Fanli Lin <fanli.lin@intel.com>
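
For context only, a minimal smoke test of the new 32-bit path might look like the sketch below. It is not part of the change; it assumes a PyTorch build with XPU support and that bitsandbytes.optim.Adam keeps its default of 32-bit optimizer state.

import torch
import bitsandbytes as bnb

# Hypothetical smoke test: run one 32-bit Adam step so that
# bitsandbytes::optimizer_update_32bit is dispatched for the active device.
device = "xpu" if torch.xpu.is_available() else "cpu"
model = torch.nn.Linear(64, 64).to(device)
opt = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=32)

x = torch.randn(8, 64, device=device)
model(x).sum().backward()
opt.step()
opt.zero_grad()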
from collections.abc import Sequence
from math import prod
from math import prod, sqrt
from typing import Optional
import torch
@@ -301,3 +301,253 @@ def _(
B_dq,
bias=None,
)
MOMENTUM = 0
RMSPROP = 1
ADAGRAD = 2
ADAM = 3
# LION must be greater than MOMENTUM, RMSPROP, and ADAGRAD because the kernels compare optimizer ids against it
LION = 4
ADEMAMIX = 5
name2optimizer_id = {
"momentum": MOMENTUM,
"rmsprop": RMSPROP,
"adagrad": ADAGRAD,
"adam": ADAM,
"lion": LION,
"ademamix": ADEMAMIX,
}
@torch.compile
def _optimizer_precondition_32bit(
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
unorm_vec: torch.Tensor,
beta1: float,
beta2: float,
eps: float,
weight_decay: float,
step: int,
lr: float,
gnorm_scale: float,
optimizer_id: int,
):
"""Preprocessing optimizer, computing update norm"""
g_vals = gnorm_scale * g
if optimizer_id == 3: # ADAM
correction1 = 1.0 / (1.0 - beta1**step)
correction2 = 1.0 / (1.0 - beta2**step)
s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals
s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals
s1_vals = s1_vals * correction1
s2_vals = s2_vals * correction2
update_vals = s1_vals / (torch.sqrt(s2_vals) + eps)
update_norm = update_vals * update_vals
elif optimizer_id == 5: # ADEMAMIX
update_norm = state1
elif optimizer_id == 0: # MOMENTUM
if step == 1:
s1_vals = g_vals
else:
s1_vals = state1 * beta1 + g_vals
update_norm = s1_vals * s1_vals
elif optimizer_id == 4: # LION
s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals
update_norm = s1_vals
elif optimizer_id == 1: # RMSPROP
s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals
update_vals = g_vals / (torch.sqrt(s1_vals) + eps)
update_norm = update_vals * update_vals
elif optimizer_id == 2: # ADAGRAD
s1_vals = state1 + g_vals * g_vals
update_vals = g_vals / (torch.sqrt(s1_vals) + eps)
update_norm = update_vals * update_vals
total_norm = torch.sum(update_norm)
unorm_vec.add_(total_norm)
@torch.compile
def _optimizer_update_32bit(
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
unorm_vec: Optional[torch.Tensor],
max_unorm: float,
param_norm: float,
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
weight_decay: float,
step: int,
lr: float,
gnorm_scale: float,
optimizer_id: int,
):
"""Unified optimizer update kernel"""
p_vals = p.float()
g_vals = (gnorm_scale * g).float()
if optimizer_id in [0, 1, 2, 4] and weight_decay > 0.0:
g_vals = g_vals + p_vals * weight_decay
update_scale = 1.0
if max_unorm > 0.0:
current_unorm = torch.sqrt(unorm_vec)
if optimizer_id in [0, 1, 2, 4]: # 1-state optimizers
if current_unorm > max_unorm * param_norm + eps:
update_scale = (max_unorm * param_norm + eps) / current_unorm
else: # 2-state optimizers
if current_unorm > max_unorm * param_norm:
update_scale = (max_unorm * param_norm) / current_unorm
if optimizer_id == 3: # ADAM
s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals
s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals
correction1 = 1.0 - beta1**step
correction2 = sqrt(1.0 - beta2**step)
step_size = -lr * correction2 / correction1
if weight_decay > 0.0:
p_vals = p_vals * (1.0 - lr * weight_decay)
update_val = update_scale * step_size * (s1_vals / (torch.sqrt(s2_vals) + eps * correction2))
p_vals = p_vals + update_val
state1.copy_(s1_vals)
state2.copy_(s2_vals)
elif optimizer_id == 5: # ADEMAMIX
s1_vals = state1[0]
s3_vals = state1[1]
s2_vals = state2
m1 = s1_vals * beta1 + (1.0 - beta1) * g_vals
m2 = s3_vals * beta3 + (1.0 - beta3) * g_vals
nu = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals
correction1 = 1.0 - beta1**step
correction2 = sqrt(1.0 - beta2**step)
if weight_decay > 0.0:
p_vals = p_vals * (1.0 - lr * weight_decay)
mixed_momentum = (m1 / correction1) + (alpha * m2)
adaptive_term = (torch.sqrt(nu) / correction2) + eps
p_vals = p_vals - lr * (mixed_momentum / adaptive_term)
state1[0].copy_(m1)
state1[1].copy_(m2)
state2.copy_(nu)
elif optimizer_id == 0: # MOMENTUM
if step == 1:
s1_vals = g_vals
else:
s1_vals = state1 * beta1 + g_vals
update_val = update_scale * (-lr * s1_vals)
p_vals = p_vals + update_val
state1.copy_(s1_vals)
elif optimizer_id == 4: # LION
momentum_update = state1 * beta1 + (1.0 - beta1) * g_vals
update_val = update_scale * lr * torch.sign(momentum_update)
p_vals = p_vals - update_val
s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals
state1.copy_(s1_vals)
elif optimizer_id == 1: # RMSPROP
s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals
update_val = update_scale * lr * g_vals / (torch.sqrt(s1_vals) + eps)
p_vals = p_vals - update_val
state1.copy_(s1_vals)
elif optimizer_id == 2: # ADAGRAD
s1_vals = state1 + g_vals * g_vals
update_val = lr * g_vals / (torch.sqrt(s1_vals) + eps)
p_vals = p_vals - update_val
state1.copy_(s1_vals)
p.copy_(p_vals)
@register_kernel("bitsandbytes::optimizer_update_32bit", "default")
def _(
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
unorm_vec: Optional[torch.Tensor],
max_unorm: float,
param_norm: float,
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
weight_decay: float,
step: int,
lr: float,
gnorm_scale: float = 1.0,
skip_zeros=False,
) -> None:
"""
    32-bit optimizer update implemented in pure PyTorch with @torch.compile.
"""
if skip_zeros:
raise NotImplementedError("skip_zeros is not supported yet")
optimizer_id = name2optimizer_id[optimizer_name]
if optimizer_name == "lion":
_optimizer_update_32bit(
g, p, state1, state2, unorm_vec, max_unorm, param_norm,
beta1, beta2, beta3, alpha, eps, weight_decay, step,
lr, gnorm_scale, optimizer_id
)
if max_unorm > 0.0:
unorm_vec.zero_()
_optimizer_precondition_32bit(
g, p, state1, state2, unorm_vec,
beta1, beta2, eps, weight_decay, step,
lr, gnorm_scale, optimizer_id
)
else:
if max_unorm > 0.0:
unorm_vec.zero_()
_optimizer_precondition_32bit(
g, p, state1, state2, unorm_vec,
beta1, beta2, eps, weight_decay, step,
lr, gnorm_scale, optimizer_id
)
_optimizer_update_32bit(
g, p, state1, state2, unorm_vec, max_unorm, param_norm,
beta1, beta2, beta3, alpha, eps, weight_decay, step,
lr, gnorm_scale, optimizer_id
)
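
Purely as an illustration (not part of the diff), the compiled fallback above could be exercised directly along the following lines; the shapes and hyperparameters are made up, and unorm_vec stays None because max_unorm is 0.0, so no update-norm clipping takes place.

# Hypothetical direct call of the pure-PyTorch Adam path (optimizer_id == ADAM).
g = torch.randn(1024)
p = torch.randn(1024)
state1 = torch.zeros(1024)  # first moment
state2 = torch.zeros(1024)  # second moment
_optimizer_update_32bit(
    g, p, state1, state2,
    unorm_vec=None, max_unorm=0.0, param_norm=0.0,
    beta1=0.9, beta2=0.999, beta3=0.0, alpha=0.0,
    eps=1e-8, weight_decay=0.0, step=1, lr=1e-3,
    gnorm_scale=1.0, optimizer_id=ADAM,
)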
from typing import Optional
import torch
import triton
import triton.language as tl
# from triton.language.extra import libdevice
MOMENTUM = 0
RMSPROP = 1
ADAGRAD = 2
ADAM = 3
# LION must be greater than MOMENTUM, RMSPROP, and ADAGRAD because the kernels compare optimizer ids against it
LION = 4
ADEMAMIX = 5
name2optimizer_id = {
"momentum": MOMENTUM,
"rmsprop": RMSPROP,
"adagrad": ADAGRAD,
"adam": ADAM,
"lion": LION,
"ademamix": ADEMAMIX,
}
@triton.jit
def _optimizer_precondition_2state_32bit(
g_ptr,
p_ptr,
state1_ptr,
state2_ptr,
unorm_ptr,
beta1: tl.constexpr,
beta2: tl.constexpr,
eps: tl.constexpr,
weight_decay: tl.constexpr,
step,
beta1_step,
beta2_step,
lr,
gnorm_scale: tl.constexpr,
n_elements,
OPTIMIZER_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
N_PER_TH: tl.constexpr,
):
"""Preprocessing optimizer, computing update norm (2-state optimizer)"""
pid = tl.program_id(axis=0)
block_start_idx = pid * N_PER_TH
offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
mask = offsets < n_elements
g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0)
s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
s2_vals = tl.load(state2_ptr + offsets, mask=mask, other=0.0)
g_vals = gnorm_scale * g_vals
correction1 = 1.0 / (1.0 - beta1_step)
correction2 = 1.0 / (1.0 - beta2_step)
if OPTIMIZER_ID == 3: # ADAM
s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals
s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals
s1_vals = s1_vals * correction1
s2_vals = s2_vals * correction2
update_vals = s1_vals / (tl.sqrt(s2_vals) + eps)
update_norm = update_vals * update_vals
elif OPTIMIZER_ID == 5: # ADEMAMIX
update_norm = s1_vals
total_norm = tl.sum(tl.where(mask, update_norm, 0.0))
tl.atomic_add(unorm_ptr, total_norm)
@triton.jit
def _optimizer_precondition_1state_32bit(
g_ptr,
p_ptr,
state1_ptr,
state2_ptr,
unorm_ptr,
beta1: tl.constexpr,
beta2: tl.constexpr,
eps: tl.constexpr,
weight_decay,
step,
beta1_step,
beta2_step,
lr,
gnorm_scale: tl.constexpr,
n_elements,
OPTIMIZER_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
N_PER_TH: tl.constexpr,
):
"""Preprocessing optimizer, computing update norm (1-state optimizer)"""
pid = tl.program_id(axis=0)
block_start_idx = pid * N_PER_TH
offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
mask = offsets < n_elements
g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0)
s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
g_vals = gnorm_scale * g_vals
if OPTIMIZER_ID == 0: # MOMENTUM
if step == 1:
s1_vals = g_vals
else:
s1_vals = s1_vals * beta1 + g_vals
update_norm = s1_vals * s1_vals
elif OPTIMIZER_ID == 4: # LION
s1_vals = s1_vals * beta2 + (1.0 - beta2) * g_vals
update_norm = s1_vals
elif OPTIMIZER_ID == 1: # RMSPROP
s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals * g_vals
update_vals = g_vals / (tl.sqrt(s1_vals) + eps)
update_norm = update_vals * update_vals
elif OPTIMIZER_ID == 2: # ADAGRAD
s1_vals = s1_vals + g_vals * g_vals
update_vals = g_vals / (tl.sqrt(s1_vals) + eps)
update_norm = update_vals * update_vals
total_norm = tl.sum(tl.where(mask, update_norm, 0.0))
tl.atomic_add(unorm_ptr, total_norm)
@triton.jit
def _optimizer_update_2state_32bit_triton_kernel(
g_ptr,
p_ptr,
state1_ptr,
state2_ptr,
unorm_ptr,
max_unorm: tl.constexpr,
param_norm,
beta1: tl.constexpr,
beta2: tl.constexpr,
beta3,
alpha,
eps: tl.constexpr,
weight_decay: tl.constexpr,
step,
beta1_step,
beta2_step,
lr,
gnorm_scale: tl.constexpr,
skip_zeros,
n_elements,
OPTIMIZER_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
N_PER_TH: tl.constexpr,
):
"""2-state optimizer kernel"""
pid = tl.program_id(axis=0)
block_start_idx = pid * N_PER_TH
offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
mask = offsets < n_elements
g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
p_vals = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
s2_vals = tl.load(state2_ptr + offsets, mask=mask, other=0.0)
if OPTIMIZER_ID == 5: # ADEMAMIX
s3_vals = tl.load(state1_ptr + n_elements + offsets, mask=mask, other=0.0)
g_vals = gnorm_scale * g_vals
update_scale = 1.0
if max_unorm > 0.0:
current_unorm = tl.sqrt(tl.load(unorm_ptr))
if current_unorm > max_unorm * param_norm:
update_scale = (max_unorm * param_norm) / current_unorm
if OPTIMIZER_ID == 3: # ADAM
s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals
s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals
correction1 = 1.0 - beta1_step
correction2 = tl.sqrt(1.0 - beta2_step)
step_size = -lr * correction2 / correction1
if weight_decay > 0.0:
p_vals = p_vals * (1.0 - lr * weight_decay)
update_val = update_scale * step_size * (s1_vals / (tl.sqrt(s2_vals) + eps * correction2))
p_vals = p_vals + update_val
elif OPTIMIZER_ID == 5: # ADEMAMIX
s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals # m1
s3_vals = s3_vals * beta3 + (1.0 - beta3) * g_vals # m2
s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals # nu
correction1 = 1.0 - beta1_step
correction2 = tl.sqrt(1.0 - beta2_step)
if weight_decay > 0.0:
p_vals = p_vals * (1.0 - lr * weight_decay)
mixed_momentum = (s1_vals / correction1) + (alpha * s3_vals)
adaptive_term = (tl.sqrt(s2_vals) / correction2) + eps
p_vals = p_vals - lr * (mixed_momentum / adaptive_term)
tl.store(p_ptr + offsets, p_vals, mask=mask)
tl.store(state1_ptr + offsets, s1_vals, mask=mask)
tl.store(state2_ptr + offsets, s2_vals, mask=mask)
if OPTIMIZER_ID == 5: # ADEMAMIX
tl.store(state1_ptr + n_elements + offsets, s3_vals, mask=mask)
@triton.jit
def _optimizer_update_1state_32bit_triton_kernel(
g_ptr,
p_ptr,
state1_ptr,
state2_ptr,
unorm_ptr,
max_unorm: tl.constexpr,
param_norm,
beta1: tl.constexpr,
beta2: tl.constexpr,
beta3,
alpha,
eps: tl.constexpr,
weight_decay: tl.constexpr,
step,
beta1_step,
beta2_step,
lr,
gnorm_scale: tl.constexpr,
skip_zeros,
n_elements,
OPTIMIZER_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
N_PER_TH: tl.constexpr,
):
"""1-state optimizer kernel"""
pid = tl.program_id(axis=0)
block_start_idx = pid * N_PER_TH
offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
mask = offsets < n_elements
g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
p_vals = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
g_vals = gnorm_scale * g_vals
if weight_decay > 0.0:
g_vals = g_vals + p_vals * weight_decay
update_scale = 1.0
if max_unorm > 0.0:
current_unorm = tl.sqrt(tl.load(unorm_ptr))
if current_unorm > max_unorm * param_norm + eps:
update_scale = (max_unorm * param_norm + eps) / current_unorm
if OPTIMIZER_ID == 0: # MOMENTUM
if step == 1:
s1_vals = g_vals
else:
s1_vals = s1_vals * beta1 + g_vals
update_val = update_scale * (-lr * s1_vals)
p_vals = p_vals + update_val
elif OPTIMIZER_ID == 4: # LION
momentum_update = s1_vals * beta1 + (1.0 - beta1) * g_vals
update_val = update_scale * lr * tl.where(momentum_update > 0, 1.0, tl.where(momentum_update < 0, -1.0, 0.0))
p_vals = p_vals - update_val
s1_vals = s1_vals * beta2 + (1.0 - beta2) * g_vals
elif OPTIMIZER_ID == 1: # RMSPROP
s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals * g_vals
update_val = update_scale * lr * g_vals / (tl.sqrt(s1_vals) + eps)
p_vals = p_vals - update_val
elif OPTIMIZER_ID == 2: # ADAGRAD
s1_vals = s1_vals + g_vals * g_vals
update_val = lr * g_vals / (tl.sqrt(s1_vals) + eps)
p_vals = p_vals - update_val
tl.store(p_ptr + offsets, p_vals, mask=mask)
tl.store(state1_ptr + offsets, s1_vals, mask=mask)
name2optimizer_32bit_fn = {
"adam": {
"preprocess": _optimizer_precondition_2state_32bit,
"update": _optimizer_update_2state_32bit_triton_kernel,
},
"ademamix": {
"preprocess": _optimizer_precondition_2state_32bit,
"update": _optimizer_update_2state_32bit_triton_kernel,
},
"momentum": {
"preprocess": _optimizer_precondition_1state_32bit,
"update": _optimizer_update_1state_32bit_triton_kernel,
},
"rmsprop": {
"preprocess": _optimizer_precondition_1state_32bit,
"update": _optimizer_update_1state_32bit_triton_kernel,
},
"adagrad": {
"preprocess": _optimizer_precondition_1state_32bit,
"update": _optimizer_update_1state_32bit_triton_kernel,
},
"lion": {
"preprocess": _optimizer_precondition_1state_32bit,
"update": _optimizer_update_1state_32bit_triton_kernel,
},
}
def optimizer_update_32bit_impl(
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
unorm_vec: Optional[torch.Tensor],
max_unorm: float,
param_norm: float,
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
weight_decay: float,
step: int,
lr: float,
gnorm_scale: float = 1.0,
skip_zeros=False,
) -> None:
"""
    32-bit optimizer update implemented with Triton kernels.
"""
if skip_zeros:
raise NotImplementedError("skip_zeros is not supported on XPU yet")
BLOCK_SIZE = 256
    N_PER_TH = 1  # Number of BLOCK_SIZE-sized chunks processed per Triton program.
grid = (triton.cdiv(p.numel(), BLOCK_SIZE * N_PER_TH),)
optimizer_id = name2optimizer_id[optimizer_name]
fn_preprocess = name2optimizer_32bit_fn[optimizer_name]["preprocess"]
fn_update = name2optimizer_32bit_fn[optimizer_name]["update"]
    # On XPU with torch==2.7, libdevice.pow raises an error.
    # For backwards compatibility, we precompute the bias-correction factors (beta**step) on the host.
beta1_step = beta1**step
beta2_step = beta2**step
if optimizer_name == "lion":
fn_update[grid](
g, p, state1, state2, unorm_vec, max_unorm, param_norm,
beta1, beta2, beta3, alpha, eps, weight_decay, step,
beta1_step, beta2_step, lr, gnorm_scale, skip_zeros,
p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
)
if max_unorm > 0.0:
unorm_vec.zero_()
fn_preprocess[grid](
g, p, state1, state2, unorm_vec,
beta1, beta2, eps, weight_decay, step,
beta1_step, beta2_step, lr, gnorm_scale,
p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
)
else:
if max_unorm > 0.0:
unorm_vec.zero_()
fn_preprocess[grid](
g, p, state1, state2, unorm_vec,
beta1, beta2, eps, weight_decay, step,
beta1_step, beta2_step, lr, gnorm_scale,
p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
)
fn_update[grid](
g, p, state1, state2, unorm_vec, max_unorm, param_norm,
beta1, beta2, beta3, alpha, eps, weight_decay, step,
beta1_step, beta2_step, lr, gnorm_scale, skip_zeros,
p.numel(), optimizer_id, BLOCK_SIZE, N_PER_TH, num_warps=2,
)
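
Again as an illustrative sketch only, the Triton launcher above could be driven directly roughly as follows, assuming an XPU device with a working Triton backend; with max_unorm=0.0 the preconditioning kernel is skipped, so unorm_vec can stay None.

# Hypothetical direct invocation of the Triton path for one Adam step on XPU.
device = torch.device("xpu")
g = torch.randn(4096, device=device)
p = torch.randn(4096, device=device)
state1 = torch.zeros(4096, device=device)
state2 = torch.zeros(4096, device=device)
optimizer_update_32bit_impl(
    "adam", g, p, state1, state2,
    unorm_vec=None, max_unorm=0.0, param_norm=0.0,
    beta1=0.9, beta2=0.999, beta3=0.0, alpha=0.0,
    eps=1e-8, weight_decay=0.0, step=1, lr=1e-3,
)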
from collections.abc import Sequence
from typing import Optional
import torch
from . import triton_kernels
from . import triton_kernels, kernels_optim
# currently unused code, kept for reference
# Should be the same for quant/dequant
@@ -175,3 +176,46 @@ def gemv_4bit(
B_dq_triton,
bias=None,
)
def optimizer_update_32bit(
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
unorm_vec: Optional[torch.Tensor],
max_unorm: float,
param_norm: float,
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
weight_decay: float,
step: int,
lr: float,
gnorm_scale: float,
skip_zeros=False,
) -> None:
with torch_accelerator_module.device(state1.device):
kernels_optim.optimizer_update_32bit_impl(
optimizer_name=optimizer_name,
g=g,
p=p,
state1=state1,
state2=state2,
unorm_vec=unorm_vec,
max_unorm=max_unorm,
param_norm=param_norm,
beta1=beta1,
beta2=beta2,
beta3=beta3,
alpha=alpha,
eps=eps,
weight_decay=weight_decay,
step=step,
lr=lr,
gnorm_scale=gnorm_scale,
skip_zeros=skip_zeros,
)
@@ -65,5 +65,6 @@ elif triton_available:
register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace)
register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
register_kernel("bitsandbytes::optimizer_update_32bit", "xpu")(triton_ops.optimizer_update_32bit)
else:
warnings.warn("XPU available but no ipex or triton packages found.")
@@ -178,6 +178,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
if optim_name.startswith("paged_") and sys.platform == "win32":
pytest.skip("Paged optimizers can have issues on Windows.")
if optim_name.startswith("paged_") and device == "xpu":
pytest.skip("Paged optimizers are not supported on XPU currently.")
if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
pytest.skip()
if dim1 == 1 and dim2 == 1: