"dev/modal/benchmarks.py" did not exist on "fe5cd1fcc6d5ddf0e39a41a33223cf3377548c7f"
Commit 9b0e3a30 authored by cmx's avatar cmx
Browse files

first commit

parent fe5cd1fc
Pipeline #3450 failed with stages
in 0 seconds
import math
import torch
from benchmark_model_configs import compute_seq_len_sweep_config
from benchmark_model_configs import estimate_kernel_peak_memory
from benchmark_model_configs import get_benchmark_model_config
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from utils import run_memory_benchmark
from utils import run_speed_benchmark
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.utils import infer_device
device = infer_device()


def _setup_geglu(input: SingleBenchmarkRunInput):
    """Build the GEGLU benchmark fixtures: an activation tensor and an MLP layer.

    The layer implementation is chosen by ``input.kernel_provider``
    ("liger" or "huggingface"); any other value raises ``ValueError``.
    ``input.x`` is the sequence length being swept.
    """
    cfg = input.extra_benchmark_config
    hf_config = LlamaConfig(
        hidden_size=cfg["hidden_size"],
        intermediate_size=cfg["intermediate_size"],
        hidden_act=cfg["hidden_act"],
    )
    # Activation tensor shaped (bsz, seq_len, hidden_size) with grad enabled
    # so backward benchmarks can run against it.
    x = torch.randn(
        cfg["bsz"],
        input.x,
        cfg["hidden_size"],
        device=device,
        dtype=cfg["dtype"],
        requires_grad=True,
    )
    provider = input.kernel_provider
    if provider == "liger":
        layer = LigerGEGLUMLP(config=hf_config).to(device).to(cfg["dtype"])
    elif provider == "huggingface":
        layer = LlamaMLP(config=hf_config).to(device).to(cfg["dtype"])
    else:
        raise ValueError(f"Invalid provider: {input.kernel_provider} for GEGLU")
    return x, layer
def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time a GEGLU forward/backward via the shared speed-benchmark harness."""
    x, layer = _setup_geglu(input)

    def forward():
        return layer(x)

    return run_speed_benchmark(forward, input.kernel_operation_mode, [x])
def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure GEGLU peak memory via the shared memory-benchmark harness."""
    x, layer = _setup_geglu(input)

    def forward():
        return layer(x)

    return run_memory_benchmark(forward, input.kernel_operation_mode)
if __name__ == "__main__":
args = parse_benchmark_script_args()
model = get_benchmark_model_config(args.model)
probe_seq_len = 1024
def _probe():
probe_input = SingleBenchmarkRunInput(
x=probe_seq_len,
kernel_provider="huggingface",
extra_benchmark_config={
"bsz": 1,
"hidden_size": model.hidden_size,
"intermediate_size": model.intermediate_size,
"hidden_act": "gelu_pytorch_tanh",
"dtype": model.dtype,
},
)
x, layer = _setup_geglu(probe_input)
return layer(x)
peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
kernel_bpt = peak_bytes // probe_seq_len
config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
common_configs = {
"kernel_name": "geglu",
"x_name": "T",
"x_label": "sequence length",
"x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"bsz": config.batch_size,
"hidden_size": model.hidden_size,
"intermediate_size": model.intermediate_size,
"hidden_act": "gelu_pytorch_tanh",
"dtype": model.dtype,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_geglu,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_geglu,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.group_norm import LigerGroupNorm
from liger_kernel.utils import infer_device
device = infer_device()


def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark GroupNorm latency: liger Triton kernel vs. torch.nn.GroupNorm.

    ``input.x`` sweeps the channel count C; batch M, spatial size H, channels
    per group, eps and dtype come from ``extra_benchmark_config``. Returns the
    20/50/80th-percentile timings in milliseconds for the requested mode.
    """
    C = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    H = extra_benchmark_config["H"]
    channels_per_group = extra_benchmark_config["channels_per_group"]
    eps = extra_benchmark_config["eps"]
    dtype = extra_benchmark_config["dtype"]
    x_shape = (M, C, H)
    triton_ln = LigerGroupNorm(num_channels=C, num_groups=C // channels_per_group, eps=eps).to(device)
    torch_ln = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=eps).to(device)
    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        # Dispatch on provider; returns None for any unknown provider.
        if provider == "liger":
            return triton_ln(x)
        if provider == "huggingface":
            return torch_ln(x)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    elif mode == "backward":
        # One forward outside the timed region; retain_graph lets the same
        # graph be backpropagated repeatedly across timing repetitions.
        y = y_fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[x],
            rep=500,
        )
    elif mode == "full":

        def full():
            y = y_fwd()
            y.backward(dy, retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure peak memory of one full GroupNorm forward + backward pass.

    ``input.x`` sweeps the channel count; the rest of the shape and dtype come
    from ``extra_benchmark_config``. Returns 20/50/80th-percentile readings.
    """
    cfg = input.extra_benchmark_config
    num_channels = input.x
    provider = input.kernel_provider
    num_groups = num_channels // cfg["channels_per_group"]
    liger_norm = LigerGroupNorm(num_channels=num_channels, num_groups=num_groups, eps=cfg["eps"]).to(device)
    torch_norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=cfg["eps"]).to(device)
    x = torch.randn((cfg["M"], num_channels, cfg["H"]), dtype=cfg["dtype"], device=device)
    grad_out = torch.randn_like(x)
    x.requires_grad_(True)

    def forward_pass():
        if provider == "liger":
            return liger_norm(x)
        if provider == "huggingface":
            return torch_norm(x)

    def run_once():
        # Forward with the selected provider, then backprop a fixed gradient.
        forward_pass().backward(grad_out, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(run_once, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "group_norm",
"x_name": "C",
"x_label": "num_channels",
"x_values": [2**i for i in range(5, 12)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"M": 128,
"H": 512,
"channels_per_group": 4,
"dtype": torch.float32,
"eps": 1e-6,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_group_norm,
kernel_operation_modes=["forward", "full", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_group_norm,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import os
import sys
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.utils import infer_device
device = infer_device()
# Make the repository root importable so the benchmark functions below can
# lazily import reference implementations from the test suite (test.chunked_loss).
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
#############################################################################
# Test the memory consumption of the linear fused GRPO loss
#############################################################################
def bench_memory_fused_linear_grpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    """Measure peak memory of one full fused-linear GRPO loss step.

    ``input.x`` sweeps the batch size B; T/H/V, dtype and the importance
    sampling level come from ``extra_benchmark_config``. Returns the
    20/50/80th-percentile peak-memory readings.
    """
    # Imported lazily: these live in the test suite, reachable via the
    # sys.path insertion at module level.
    from test.chunked_loss.test_grpo_loss import LigerLMHeadGRPO
    from test.chunked_loss.test_grpo_loss import TorchLMHeadGRPO

    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    importance_sampling_level = input.extra_benchmark_config["importance_sampling_level"]
    provider = input.kernel_provider
    # Instantiate once and retrieve the first output only
    torch_lm_head_grpo = TorchLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to(
        device
    )
    liger_lm_head_grpo = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to(
        device
    )
    # Create inputs
    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    selected_token_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device)
    attention_mask = torch.ones(B, T, device=device)
    advantages = torch.randn(B, dtype=dtype, device=device)
    ref_input = torch.randn(B, T, H, dtype=dtype, device=device)
    # Both heads return a tuple; [0] keeps only the scalar loss.
    torch_fwd = lambda: torch_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[
        0
    ]
    liger_fwd = lambda: liger_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[
        0
    ]

    def fwd():
        if provider == "liger":
            return liger_fwd()
        elif provider == "torch":
            return torch_fwd()

    def full():
        y = fwd()
        y.backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )
#############################################################################
# Test the speed of the fused linear GRPO loss
#############################################################################
def bench_speed_fused_linear_grpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    """Benchmark latency of the fused-linear GRPO loss (liger vs. torch).

    ``input.x`` sweeps the batch size B; T/H/V, dtype and the importance
    sampling level come from ``extra_benchmark_config``. Returns the
    20/50/80th-percentile timings in milliseconds for the requested mode.
    """
    # Imported lazily: these live in the test suite, reachable via the
    # sys.path insertion at module level.
    from test.chunked_loss.test_grpo_loss import LigerLMHeadGRPO
    from test.chunked_loss.test_grpo_loss import TorchLMHeadGRPO

    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    importance_sampling_level = input.extra_benchmark_config["importance_sampling_level"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    # Instantiate once and retrieve the first output only
    torch_lm_head_grpo = TorchLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to(
        device
    )
    liger_lm_head_grpo = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to(
        device
    )
    # Create inputs
    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    selected_token_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device)
    attention_mask = torch.ones(B, T, device=device)
    advantages = torch.randn(B, dtype=dtype, device=device)
    ref_input = torch.randn(B, T, H, dtype=dtype, device=device)
    # Both heads return a tuple; [0] keeps only the scalar loss.
    torch_fwd = lambda: torch_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[
        0
    ]
    liger_fwd = lambda: liger_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[
        0
    ]

    def fwd():
        if provider == "liger":
            return liger_fwd()
        elif provider == "torch":
            return torch_fwd()

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        # One forward outside the timed region; retain_graph allows repeated
        # backward passes over the same graph during timing.
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
if __name__ == "__main__":
args = parse_benchmark_script_args()
# Benchmark token-level importance sampling (original GRPO)
token_configs = {
"kernel_name": "fused_linear_grpo_loss_token",
"x_name": "B",
"x_label": "B",
"x_values": [2**i for i in range(1, 5)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [
{
"T": 1024,
"H": 4096,
"V": 128256,
"importance_sampling_level": "token",
"dtype": torch.bfloat16,
}
],
"overwrite": args.overwrite,
}
# Benchmark sequence-level importance sampling (GSPO)
sequence_configs = {
"kernel_name": "fused_linear_grpo_loss_sequence",
"x_name": "B",
"x_label": "B",
"x_values": [2**i for i in range(1, 5)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [
{
"T": 1024,
"H": 4096,
"V": 128256,
"importance_sampling_level": "sequence",
"dtype": torch.bfloat16,
}
],
"overwrite": args.overwrite,
}
# Run benchmarks for token-level (GRPO)
print("Benchmarking GRPO (token-level importance sampling)...")
run_benchmarks(
bench_test_fn=bench_speed_fused_linear_grpo_loss,
kernel_operation_modes=["forward", "full", "backward"],
metric_name="speed",
metric_unit="ms",
**token_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_fused_linear_grpo_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**token_configs,
)
# Run benchmarks for sequence-level (GSPO)
print("Benchmarking GSPO (sequence-level importance sampling)...")
run_benchmarks(
bench_test_fn=bench_speed_fused_linear_grpo_loss,
kernel_operation_modes=["forward", "full", "backward"],
metric_name="speed",
metric_unit="ms",
**sequence_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_fused_linear_grpo_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**sequence_configs,
)
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.jsd import LigerJSD
from liger_kernel.utils import get_total_gpu_memory
from liger_kernel.utils import infer_device
# Resolve the accelerator device once for all JSD benchmarks below.
device = infer_device()
class TorchJSD(torch.nn.Module):
    """Generalized Jensen-Shannon divergence, used as the torch baseline.

    Computes ``beta * KL(P || M) + (1 - beta) * KL(Q || M)`` with the mixture
    ``M = beta * P + (1 - beta) * Q``, where both inputs are log-probability
    tensors. ``beta=0.5`` gives the classic symmetric JSD. When ``label`` is
    supplied, positions equal to ``ignore_index`` are excluded from the mean.
    """

    def __init__(
        self,
        beta: float = 0.5,
        ignore_index: int = -100,
        dtype: torch.dtype = torch.float,
    ):
        super(TorchJSD, self).__init__()
        # Elementwise KL with log-space targets; reduction is done manually below.
        self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True)
        self.beta = beta
        self.ignore_index = ignore_index
        self.dtype = dtype

    def forward(
        self,
        log_q: torch.Tensor,  # input (log-probabilities)
        log_p: torch.Tensor,  # target (log-probabilities)
        label=None,
    ):
        # Compute in float32 over flattened (num_tokens, vocab) views.
        log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
        log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
        # Mixture M = beta * P + (1 - beta) * Q, built in probability space.
        m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
        # Per-token divergence: beta * KL(P||M) + (1 - beta) * KL(Q||M).
        loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (1 - self.beta) * self.kl(
            torch.log(m), log_q
        ).sum(dim=-1)
        if label is not None:
            loss = torch.where(label != self.ignore_index, loss, 0.0)
            n_non_ignore = (label != self.ignore_index).sum().item()
            if n_non_ignore == 0:
                # Fix: keep the result a tensor. The previous `loss = 0.0`
                # (a Python float) crashed on the final `.to(self.dtype)`
                # whenever every position was ignored.
                loss = torch.zeros((), dtype=torch.float, device=log_q.device)
            else:
                loss = (loss / n_non_ignore).sum()
        else:
            # Mean over tokens when no label mask is supplied.
            loss = (loss / log_q.shape[0]).sum()
        return loss.to(self.dtype)
def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark JSD latency: liger kernel vs. the TorchJSD reference.

    ``input.x`` sweeps the vocab size V; batch B and sequence length T come
    from ``extra_benchmark_config``. Returns 20/50/80th-percentile timings (ms).
    """
    V = input.x
    B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
    torch_jsd = TorchJSD()
    liger_jsd = LigerJSD()
    # _input is the (non-leaf) log_softmax output; gradients flow back to the
    # randn leaf created with requires_grad=True.
    _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1)
    target = torch.randn(B * T, V, device=device).log_softmax(dim=-1)

    def fwd():
        if input.kernel_provider == "liger":
            return liger_jsd(_input, target)
        else:
            return torch_jsd(_input, target)

    if input.kernel_operation_mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
    elif input.kernel_operation_mode == "backward":
        # One forward outside the timed region; retain_graph allows repeated
        # backward passes during timing.
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[_input],
            rep=100,
        )
    elif input.kernel_operation_mode == "full":

        def full():
            y = fwd()
            y.backward(retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure peak memory of one full JSD forward + backward pass."""
    torch_jsd = TorchJSD()
    liger_jsd = LigerJSD()
    vocab_size = input.x
    batch, seq_len = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
    _input = torch.randn(batch * seq_len, vocab_size, requires_grad=True, device=device).log_softmax(dim=-1)
    target = torch.randn(batch * seq_len, vocab_size, device=device).log_softmax(dim=-1)

    def run_once():
        # Forward with the selected provider, then backprop through the graph.
        loss_fn = liger_jsd if input.kernel_provider == "liger" else torch_jsd
        loss = loss_fn(_input, target)
        loss.backward(retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(run_once, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
if __name__ == "__main__":
args = parse_benchmark_script_args()
gpu_memory_gbs = get_total_gpu_memory()
# We know that the full test will require 54GBs for vocab size 2^17 on torch
if gpu_memory_gbs >= 54:
x_max = 17
else:
x_max = 16
common_args = {
"kernel_name": "jsd",
"x_name": "V",
"x_label": "vocab size",
"x_values": [2**i for i in range(12, x_max + 1)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [{"B": 4, "T": 2048}],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_memory_jsd,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_args,
)
run_benchmarks(
bench_test_fn=bench_speed_jsd,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_args,
)
import torch
import torch.nn as nn
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.kl_div import LigerKLDIVLoss
from liger_kernel.utils import infer_device
device = infer_device()
# NOTE(review): S and E appear unused by the kl_div benchmark below —
# confirm before removing.
S, E = 12, 18
def bench_speed_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark KL-divergence latency: liger kernel vs. torch.nn.KLDivLoss.

    ``input.x`` sweeps the vocab size V; batch B and sequence length T come
    from ``extra_benchmark_config``. Returns 20/50/80th-percentile timings (ms).
    """
    reduction = "batchmean"
    V = input.x
    B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
    torch_kl_div = nn.KLDivLoss(reduction=reduction)
    liger_kl_div = LigerKLDIVLoss(reduction=reduction)
    # Input is log-probabilities, target is probabilities (KLDivLoss convention).
    _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1)
    target = torch.randn(B * T, V, device=device).softmax(dim=-1)

    def fwd():
        if input.kernel_provider == "liger":
            return liger_kl_div(_input, target)
        else:
            return torch_kl_div(_input, target)

    if input.kernel_operation_mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
    elif input.kernel_operation_mode == "backward":
        # One forward outside the timed region; retain_graph allows repeated
        # backward passes during timing.
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[_input],
            rep=100,
        )
    elif input.kernel_operation_mode == "full":

        def full():
            y = fwd()
            y.backward(retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
def bench_memory_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure peak memory of one full KL-divergence forward + backward pass."""
    reduction = "batchmean"
    torch_kl_div = nn.KLDivLoss(reduction=reduction)
    liger_kl_div = LigerKLDIVLoss(reduction=reduction)
    vocab_size = input.x
    batch, seq_len = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
    # Input is log-probabilities, target is probabilities (KLDivLoss convention).
    _input = torch.randn(batch * seq_len, vocab_size, requires_grad=True, device=device).log_softmax(dim=-1)
    target = torch.randn(batch * seq_len, vocab_size, device=device).softmax(dim=-1)

    def run_once():
        loss_fn = liger_kl_div if input.kernel_provider == "liger" else torch_kl_div
        loss_fn(_input, target).backward(retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(run_once, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_args = {
"kernel_name": "kl_div",
"x_name": "V",
"x_label": "vocab size",
"x_values": [2**i for i in range(12, 18)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [{"B": 8, "T": 512}],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_memory_kldiv,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_args,
)
run_benchmarks(
bench_test_fn=bench_speed_kldiv,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_args,
)
import os
import sys
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss
from liger_kernel.utils import infer_device
device = infer_device()
# Make the repository root importable so HFKTOLoss can be imported lazily
# from the test suite inside TorchLMHeadKTO.__init__ below.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
class TorchLMHeadKTO(torch.nn.Module):
    """LM head + HuggingFace-style KTO loss, used as the torch baseline.

    Holds a policy projection (``lin``) and a reference projection
    (``ref_lin``); the loss computation itself is delegated to ``HFKTOLoss``
    from the test suite.
    """

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        use_bias: bool = False,
        use_ref_bias: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
    ):
        # Imported lazily: requires the repo root to be importable
        # (sys.path is amended at module level).
        from test.chunked_loss.test_kto_loss import HFKTOLoss

        super().__init__()
        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_bias, dtype=dtype)
        self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_ref_bias, dtype=dtype)
        # Bound method: calling self.KTO_loss(...) runs get_batch_loss_metrics.
        self.KTO_loss = HFKTOLoss(
            ignore_index=ignore_index,
            beta=beta,
            use_ref_model=True,
        ).get_batch_loss_metrics

    def forward(self, x, ref_x, y, preference_labels, kl=None):
        return self.KTO_loss(
            weight=self.lin.weight,
            _input=x,
            target=y,
            bias=self.lin.bias,
            ref_input=ref_x,
            ref_weight=self.ref_lin.weight,
            ref_bias=self.ref_lin.bias,
            preference_labels=preference_labels,
            kl=kl,
        )
class LigerLMHeadKTO(torch.nn.Module):
    """LM head + Liger fused-linear KTO loss, mirroring TorchLMHeadKTO."""

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        use_bias: bool = False,
        use_ref_bias: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
    ):
        super().__init__()
        # Policy and reference projection heads.
        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_bias, dtype=dtype)
        self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_ref_bias, dtype=dtype)
        self.KTO_loss = LigerFusedLinearKTOLoss(ignore_index=ignore_index, beta=beta, use_ref_model=True)

    def forward(self, x, ref_x, y, preference_labels, kl=None):
        # Delegate to the fused loss; weights and biases come from the heads.
        loss_kwargs = dict(
            _input=x,
            lin_weight=self.lin.weight,
            target=y,
            preference_labels=preference_labels,
            bias=self.lin.bias,
            ref_input=ref_x,
            ref_weight=self.ref_lin.weight,
            ref_bias=self.ref_lin.bias,
            kl=kl,
        )
        return self.KTO_loss(**loss_kwargs)
def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure peak memory of one full KTO loss step (liger vs. huggingface).

    ``input.x`` sweeps the batch size B; T/H/V, dtype, bias, beta and
    ignore_index come from ``extra_benchmark_config``. Returns the
    20/50/80th-percentile peak-memory readings.
    """
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    bias = input.extra_benchmark_config["bias"]
    beta = input.extra_benchmark_config["beta"]
    ignore_index = input.extra_benchmark_config["ignore_index"]
    provider = input.kernel_provider
    torch_kto_loss = TorchLMHeadKTO(
        H=H,
        V=V,
        dtype=dtype,
        use_bias=bias,
        use_ref_bias=bias,
        ignore_index=ignore_index,
        beta=beta,
    ).to(device)
    liger_kto_loss = LigerLMHeadKTO(
        H=H,
        V=V,
        dtype=dtype,
        use_bias=bias,
        use_ref_bias=bias,
        ignore_index=ignore_index,
        beta=beta,
    ).to(device)
    # Input shape: [B, T, H]
    _input = torch.randn(B, T, H, device=device, dtype=dtype)
    # Target shape: [B, T]
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)
    # Preference labels shape: [B]
    # Create binary preference labels (0 or 1) for each sequence in the batch
    # Used to indicate preferred sequences (1) vs non-preferred sequences (0)
    preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device)
    # Precomputed KL divergence between policy and reference distributions
    kl = torch.randn(1, device=device, dtype=dtype)
    # Add ignore_index tokens to simulate padding
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
    target.view(-1)[indices_to_assign] = ignore_index
    # Add ref_x with the same shape as _input
    ref_input = torch.randn(B, T, H, device=device, dtype=dtype)

    def fwd():
        # Both heads return a tuple; [0] keeps only the scalar loss.
        if provider == "liger":
            return liger_kto_loss(
                x=_input,
                ref_x=ref_input,
                y=target,
                preference_labels=preference_labels,
                kl=kl,
            )[0]
        elif provider == "huggingface":
            return torch_kto_loss(
                x=_input,
                ref_x=ref_input,
                y=target,
                preference_labels=preference_labels,
                kl=kl,
            )[0]

    def full():
        y = fwd()
        y.backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )
def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark latency of the KTO loss (liger vs. huggingface).

    ``input.x`` sweeps the batch size B; T/H/V, dtype, bias, beta and
    ignore_index come from ``extra_benchmark_config``. Returns the
    20/50/80th-percentile timings in milliseconds for the requested mode.
    """
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    bias = input.extra_benchmark_config["bias"]
    beta = input.extra_benchmark_config["beta"]
    ignore_index = input.extra_benchmark_config["ignore_index"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    # Consistency fix: pass use_ref_bias=bias exactly as bench_memory_kto_loss
    # does, so the speed and memory benchmarks measure the same model
    # configuration (previously the reference head silently kept its
    # default of no bias).
    torch_kto_loss = TorchLMHeadKTO(
        H=H,
        V=V,
        dtype=dtype,
        use_bias=bias,
        use_ref_bias=bias,
        ignore_index=ignore_index,
        beta=beta,
    ).to(device)
    liger_kto_loss = LigerLMHeadKTO(
        H=H,
        V=V,
        dtype=dtype,
        use_bias=bias,
        use_ref_bias=bias,
        ignore_index=ignore_index,
        beta=beta,
    ).to(device)
    # Input shape: [B, T, H]
    _input = torch.randn(B, T, H, device=device, dtype=dtype)
    # Target shape: [B, T]
    target = torch.randint(V, (B, T), device=device, dtype=torch.long)
    # Binary preference labels per sequence: 1 = preferred, 0 = non-preferred.
    preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device)
    # Precomputed KL divergence between policy and reference distributions
    kl = torch.randn(1, device=device, dtype=dtype)
    # Add ignore_index tokens to simulate padding
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
    target.view(-1)[indices_to_assign] = ignore_index
    # Add ref_x with the same shape as _input
    ref_input = torch.randn(B, T, H, device=device, dtype=dtype)

    def fwd():
        # Both heads return a tuple; [0] keeps only the scalar loss.
        if provider == "liger":
            return liger_kto_loss(
                x=_input,
                ref_x=ref_input,
                y=target,
                preference_labels=preference_labels,
                kl=kl,
            )[0]
        elif provider == "huggingface":
            return torch_kto_loss(
                x=_input,
                ref_x=ref_input,
                y=target,
                preference_labels=preference_labels,
                kl=kl,
            )[0]

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        # One forward outside the timed region; retain_graph allows repeated
        # backward passes over the same graph during timing.
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "kto_loss",
"x_name": "B",
"x_label": "Batch Size (B)",
"x_values": [2**i for i in range(1, 6)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"T": 512,
"H": 1024,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
"bias": True,
"beta": 0.1,
"ignore_index": 42,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_kto_loss,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_kto_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.utils import infer_device
device = infer_device()


def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark LayerNorm latency: liger Triton kernel vs. torch.nn.LayerNorm.

    ``input.x`` sweeps the hidden size N; batch M, eps and dtype come from
    ``extra_benchmark_config``. Returns 20/50/80th-percentile timings (ms).
    """
    N = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    eps = extra_benchmark_config["eps"]
    dtype = extra_benchmark_config["dtype"]
    x_shape = (M, N)
    # NOTE(review): eps is only forwarded to torch_ln; triton_ln is built with
    # its default eps — confirm whether LigerLayerNorm should receive eps too
    # for an apples-to-apples comparison.
    triton_ln = LigerLayerNorm(hidden_size=N).to(device)
    torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)
    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        # Dispatch on provider; returns None for any unknown provider.
        if provider == "liger":
            return triton_ln(x)
        if provider == "huggingface":
            return torch_ln(x)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    elif mode == "backward":
        # One forward outside the timed region; retain_graph allows repeated
        # backward passes over the same graph during timing.
        y = y_fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[x],
            rep=500,
        )
    elif mode == "full":

        def full():
            y = y_fwd()
            y.backward(dy, retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure peak memory of one full LayerNorm forward + backward pass."""
    hidden = input.x
    provider = input.kernel_provider
    cfg = input.extra_benchmark_config
    rows, eps, dtype = cfg["M"], cfg["eps"], cfg["dtype"]
    # NOTE(review): eps is only forwarded to the torch implementation; the
    # liger layer keeps its default — confirm whether it should receive eps.
    triton_ln = LigerLayerNorm(hidden_size=hidden).to(device)
    torch_ln = torch.nn.LayerNorm(hidden, eps=eps).to(device)
    x = torch.randn((rows, hidden), dtype=dtype, device=device)
    grad_out = torch.randn_like(x)
    x.requires_grad_(True)

    def forward_pass():
        if provider == "liger":
            return triton_ln(x)
        if provider == "huggingface":
            return torch_ln(x)

    def run_once():
        forward_pass().backward(grad_out, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(run_once, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "layer_norm",
"x_name": "N",
"x_label": "hidden size",
"x_values": [2**i for i in range(10, 15)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [{"M": 4096, "dtype": torch.float32, "eps": 1e-6}],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_layer_norm,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_layer_norm,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import torch
import triton
from transformers.models.llama4.configuration_llama4 import Llama4TextConfig
from transformers.models.llama4.modeling_llama4 import Llama4TextRotaryEmbedding
from transformers.models.llama4.modeling_llama4 import apply_rotary_emb
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb
from liger_kernel.utils import infer_device
from liger_kernel.utils import transformers_version_dispatch
# Resolve the accelerator device once for the Llama4 RoPE benchmarks below.
device = infer_device()
def bench_speed_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark the latency of Llama4 rotary position embeddings.

    Compares the Liger Triton kernel against the HuggingFace reference
    (``apply_rotary_emb``) in forward, backward, and full (fwd + bwd) modes.

    Args:
        input: Benchmark run description. ``input.x`` is interpreted as the
            hidden size or the sequence length, whichever is *not* pinned in
            ``extra_benchmark_config``.

    Raises:
        ValueError: If the kernel provider or operation mode is unrecognized.
    """
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len; the other one is fixed in the config.
    hidden_size = extra_benchmark_config.get("hidden_size", input.x)
    seq_len = extra_benchmark_config.get("seq_len", input.x)
    head_dim = hidden_size // num_q_heads

    # The rotary embedding derives its frequency table from the model config.
    config = Llama4TextConfig(
        hidden_size=hidden_size,
        num_attention_heads=num_q_heads,
        num_key_value_heads=num_kv_heads,
        head_dim=head_dim,
        max_position_embeddings=seq_len,
    )
    rotary_emb = transformers_version_dispatch(
        "4.48.0",
        Llama4TextRotaryEmbedding,
        Llama4TextRotaryEmbedding,
        before_kwargs={"config": config, "device": device},
        after_kwargs={"config": config, "device": device},
    )
    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    )
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    )
    # Upstream gradients. randn_like inherits device and dtype from q/k, which
    # keeps dq and dk consistent (previously dk silently omitted dtype=).
    dq, dk = torch.randn_like(q), torch.randn_like(k)
    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    freqs_cis = rotary_emb(q, pos_ids)

    def fwd():
        if provider == "liger":
            return liger_llama4_text_rotary_pos_emb(q, k, freqs_cis)
        elif provider == "huggingface":
            return apply_rotary_emb(q, k, freqs_cis)
        else:
            raise ValueError(f"Invalid provider: {provider} for Llama4 RoPE embedding")

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        # Build the graph once, then time repeated backward passes only.
        q_out, k_out = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True),
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            q_out, k_out = fwd()
            torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    else:
        # Previously an unknown mode fell through to an UnboundLocalError.
        raise ValueError(f"Invalid mode: {mode} for Llama4 RoPE embedding")
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
def bench_memory_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark peak memory of a full fwd+bwd pass of Llama4 rotary embeddings.

    Compares the Liger Triton kernel against the HuggingFace reference
    (``apply_rotary_emb``). ``input.x`` is interpreted as the hidden size or
    the sequence length, whichever is *not* pinned in ``extra_benchmark_config``.
    """
    provider = input.kernel_provider
    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len; the other one is fixed in the config.
    hidden_size = extra_benchmark_config.get("hidden_size", input.x)
    seq_len = extra_benchmark_config.get("seq_len", input.x)
    head_dim = hidden_size // num_q_heads

    # The rotary embedding derives its frequency table from the model config.
    config = Llama4TextConfig(
        hidden_size=hidden_size,
        num_attention_heads=num_q_heads,
        num_key_value_heads=num_kv_heads,
        head_dim=head_dim,
        max_position_embeddings=seq_len,
    )
    rotary_emb = transformers_version_dispatch(
        "4.48.0",
        Llama4TextRotaryEmbedding,
        Llama4TextRotaryEmbedding,
        before_kwargs={"config": config, "device": device},
        after_kwargs={"config": config, "device": device},
    )
    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    )
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    )
    # Upstream gradients. randn_like inherits device and dtype from q/k, which
    # keeps dq and dk consistent (previously dk silently omitted dtype=).
    dq, dk = torch.randn_like(q), torch.randn_like(k)
    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    freqs_cis = rotary_emb(q, pos_ids)

    def full():
        # Any provider other than "liger" falls back to the HF reference.
        if provider == "liger":
            q_out, k_out = liger_llama4_text_rotary_pos_emb(q, k, freqs_cis)
        else:
            q_out, k_out = apply_rotary_emb(q, k, freqs_cis)
        torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(
        full,
        quantiles=QUANTILES,
    )
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )
if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Two sweeps: hidden size at fixed seq_len, then seq_len at fixed hidden
    # size. For each sweep, run speed first and memory second (same order as
    # the expanded version of this script).
    base = {
        "kernel_name": "llama4_rope",
        "kernel_providers": ["liger", "huggingface"],
        "overwrite": args.overwrite,
    }
    sweeps = [
        dict(
            base,
            x_name="H",
            x_label="hidden size",
            x_values=[32 * (2**i) for i in range(4, 10, 2)],
            extra_benchmark_configs=[
                {
                    "dtype": torch.bfloat16,
                    "seq_len": 2048,
                    "num_q_heads": 32,
                    "num_kv_heads": 8,
                }
            ],
        ),
        dict(
            base,
            x_name="T",
            x_label="sequence length",
            x_values=[2**i for i in range(10, 15)],
            extra_benchmark_configs=[
                {
                    "dtype": torch.bfloat16,
                    "hidden_size": 8192,
                    "num_q_heads": 32,
                    "num_kv_heads": 8,
                }
            ],
        ),
    ]
    for sweep in sweeps:
        run_benchmarks(
            bench_test_fn=bench_speed_llama4_rope,
            kernel_operation_modes=["forward", "backward", "full"],
            metric_name="speed",
            metric_unit="ms",
            **sweep,
        )
        run_benchmarks(
            bench_test_fn=bench_memory_llama4_rope,
            kernel_operation_modes=["full"],
            metric_name="memory",
            metric_unit="MB",
            **sweep,
        )
import os
import sys
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.functional import liger_mhc_coeffs
from liger_kernel.transformers.functional import liger_mhc_post_res
from liger_kernel.transformers.functional import liger_mhc_pre
from liger_kernel.utils import infer_device
device = infer_device()
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
def bench_speed_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark forward/backward/full latency of one MHC sub-kernel.

    The sub-kernel ("coeffs", "pre", or "post_res") is selected via
    ``extra_benchmark_config["sub_kernel"]``; the Liger Triton kernel is
    compared against a pure-PyTorch reference from the test suite.

    Raises:
        ValueError: If the sub-kernel name or operation mode is unrecognized.
    """
    # Imported lazily: the reference implementation lives in the test suite.
    from test.transformers.test_mhc import mhc_coeffs_ref

    T = input.x
    B = input.extra_benchmark_config["B"]
    HC = input.extra_benchmark_config["HC"]
    C = input.extra_benchmark_config["C"]
    sub_kernel = input.extra_benchmark_config["sub_kernel"]
    tmax = input.extra_benchmark_config["tmax"]
    rms_eps = input.extra_benchmark_config["rms_eps"]
    pre_eps = input.extra_benchmark_config["pre_eps"]
    sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"]
    post_mult = input.extra_benchmark_config["post_mult"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult)
    # Only the backward/full modes need a graph; skip grad tracking otherwise.
    need_grad = mode in ("backward", "full")

    x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=need_grad)
    K, M = HC * C, HC * HC + 2 * HC
    phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(need_grad)
    b_param = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=need_grad)
    alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad)
    alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad)
    alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad)
    grad_to_none = [x, phi, b_param, alpha_pre, alpha_post, alpha_res] if need_grad else None

    if sub_kernel == "coeffs":

        def fwd():
            if provider == "liger":
                return liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
            return mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)

        def fwd_loss():
            # Scalar loss over all three outputs so backward exercises everything.
            h_pre, h_post, h_res = fwd()
            return h_pre.square().mean() + h_post.square().mean() + h_res.square().mean()

    elif sub_kernel == "pre":
        # Precompute coefficients once (no grad) so only the "pre" kernel
        # itself is timed.
        with torch.no_grad():
            h_pre_c, _, _ = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_pre_c.requires_grad_(need_grad)
        grad_to_none = [x, h_pre_c] if need_grad else None

        def fwd():
            if provider == "liger":
                return liger_mhc_pre(x, h_pre_c)
            return (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2)

        def fwd_loss():
            return fwd().square().mean()

    elif sub_kernel == "post_res":
        # Precompute coefficients once (no grad) so only "post_res" is timed.
        with torch.no_grad():
            _, h_post_c, h_res_c = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_post_c.requires_grad_(need_grad)
        h_res_c.requires_grad_(need_grad)
        f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16, requires_grad=need_grad)
        grad_to_none = [x, f_out, h_post_c, h_res_c] if need_grad else None

        def fwd():
            if provider == "liger":
                return liger_mhc_post_res(x, f_out, h_post_c, h_res_c)
            return torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze(
                -1
            ) * f_out.float().unsqueeze(-2)

        def fwd_loss():
            return fwd().square().mean()

    else:
        # Previously an unknown sub-kernel left fwd/fwd_loss undefined (NameError).
        raise ValueError(f"Invalid sub_kernel: {sub_kernel} for MHC")

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
    elif mode == "backward":
        # Build the graph once, then time repeated backward passes only.
        y = fwd_loss()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=grad_to_none,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd_loss()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=grad_to_none, rep=100, quantiles=QUANTILES)
    else:
        # Previously an unknown mode fell through to an UnboundLocalError.
        raise ValueError(f"Invalid mode: {mode} for MHC")
    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
def bench_memory_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark peak memory of a full fwd+bwd pass of one MHC sub-kernel.

    The sub-kernel ("coeffs", "pre", or "post_res") is selected via
    ``extra_benchmark_config["sub_kernel"]``; the Liger Triton kernel is
    compared against a pure-PyTorch reference from the test suite.

    Raises:
        ValueError: If the sub-kernel name is unrecognized.
    """
    # Imported lazily: the reference implementation lives in the test suite.
    from test.transformers.test_mhc import mhc_coeffs_ref

    T = input.x
    B = input.extra_benchmark_config["B"]
    HC = input.extra_benchmark_config["HC"]
    C = input.extra_benchmark_config["C"]
    sub_kernel = input.extra_benchmark_config["sub_kernel"]
    tmax = input.extra_benchmark_config["tmax"]
    rms_eps = input.extra_benchmark_config["rms_eps"]
    pre_eps = input.extra_benchmark_config["pre_eps"]
    sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"]
    post_mult = input.extra_benchmark_config["post_mult"]
    provider = input.kernel_provider

    coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult)
    x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=True)
    K, M = HC * C, HC * HC + 2 * HC
    phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(True)
    b_param = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=True)
    alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)
    alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)
    alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)

    if sub_kernel == "coeffs":

        def full():
            if provider == "liger":
                hp, hpo, hr = liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
            else:
                hp, hpo, hr = mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
            (hp.square().mean() + hpo.square().mean() + hr.square().mean()).backward()

    elif sub_kernel == "pre":
        # Precompute coefficients once (no grad) so only the "pre" kernel's
        # memory footprint is measured.
        with torch.no_grad():
            h_pre_c, _, _ = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_pre_c.requires_grad_(True)

        def full():
            if provider == "liger":
                out = liger_mhc_pre(x, h_pre_c)
            else:
                out = (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2)
            out.square().mean().backward()

    elif sub_kernel == "post_res":
        # Precompute coefficients once (no grad) so only "post_res" is measured.
        with torch.no_grad():
            _, h_post_c, h_res_c = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_post_c.requires_grad_(True)
        h_res_c.requires_grad_(True)
        f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16, requires_grad=True)

        def full():
            if provider == "liger":
                out = liger_mhc_post_res(x, f_out, h_post_c, h_res_c)
            else:
                out = torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze(
                    -1
                ) * f_out.float().unsqueeze(-2)
            out.square().mean().backward()

    else:
        # Previously an unknown sub-kernel left `full` undefined (NameError).
        raise ValueError(f"Invalid sub_kernel: {sub_kernel} for MHC")

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Benchmark each MHC sub-kernel independently: speed sweep first, then
    # memory, both over the same sequence-length axis.
    for sub_kernel in ["coeffs", "pre", "post_res"]:
        extra_cfg = {
            "B": 4,
            "HC": 4,
            "C": 4096,
            "tmax": 20,
            "rms_eps": 1e-6,
            "pre_eps": 0.0,
            "sinkhorn_eps": 1e-6,
            "post_mult": 2.0,
            "sub_kernel": sub_kernel,
        }
        shared = dict(
            kernel_name=f"mhc_{sub_kernel}",
            x_name="T",
            x_label="Sequence Length (T)",
            x_values=[2**i for i in range(7, 12)],
            kernel_providers=["liger", "torch"],
            extra_benchmark_configs=[extra_cfg],
            overwrite=args.overwrite,
        )
        run_benchmarks(
            bench_test_fn=bench_speed_mhc,
            kernel_operation_modes=["forward", "backward", "full"],
            metric_name="speed",
            metric_unit="ms",
            **shared,
        )
        run_benchmarks(
            bench_test_fn=bench_memory_mhc,
            kernel_operation_modes=["full"],
            metric_name="memory",
            metric_unit="MB",
            **shared,
        )
This diff is collapsed.
"""
Standardized benchmark model configurations.
Provides canonical model architecture profiles and device-specific benchmark
parameters. All benchmark scripts should derive their tensor shapes from these
shared configs rather than defining ad-hoc per-script constants.
Usage::
from benchmark_model_configs import (
get_benchmark_model_config,
compute_seq_len_sweep_config,
estimate_kernel_peak_memory,
)
args = parse_benchmark_script_args()
model = get_benchmark_model_config(args.model)
# Measure actual memory via a small probe, then compute sweep config
peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
bpt = peak_bytes // probe_num_tokens
config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=bpt)
"""
import gc
import math
from dataclasses import dataclass
from typing import Callable
from typing import Dict
from typing import Optional
import torch
from liger_kernel.utils import get_total_gpu_memory
from liger_kernel.utils import infer_device
@dataclass(frozen=True)
class ModelConfig:
    """Canonical model architecture profile.

    Each field corresponds to a standard LLM hyperparameter. Benchmark scripts
    pick the fields they need (e.g. hidden_size for RMSNorm, vocab_size for
    CrossEntropy) while kernel-specific overrides (e.g. hidden_act for GEGLU)
    are applied locally in the benchmark script.
    """
    name: str  # registry key, e.g. "llama_3_8b" (matches MODEL_REGISTRY)
    hidden_size: int  # embedding / model dimension
    intermediate_size: int  # MLP inner dimension
    vocab_size: int
    num_attention_heads: int
    num_key_value_heads: int  # may differ from num_attention_heads
    head_dim: int
    hidden_act: str  # activation name, e.g. "silu"
    max_position_embeddings: int = 8192  # model's native context window
    rms_norm_eps: float = 1e-5
    dtype: torch.dtype = torch.bfloat16  # default tensor dtype for benchmarks
@dataclass(frozen=True)
class SeqLenSweepConfig:
    """Config for benchmarks that sweep sequence length (e.g. GEGLU, SwiGLU).

    Attributes:
        batch_size: Safe batch size for the sweep.
        seq_len: Max sequence length (upper bound for x_values).
    """
    batch_size: int  # derived from the device memory budget
    seq_len: int  # power-of-two upper bound for the sweep
@dataclass(frozen=True)
class HiddenSizeSweepConfig:
    """Config for benchmarks that sweep hidden_size with fixed BT (e.g. DyT).

    Attributes:
        bt: Fixed batch * seq dimension.
        max_hidden_size: Upper bound for hidden_size sweep.
    """
    bt: int  # must match the BT used for the memory probe
    max_hidden_size: int  # power-of-two upper bound for the sweep
# ── Model Profiles ──────────────────────────────────────────────────────────
# Llama 2 7B: 32 query / 32 KV heads, 32k vocabulary, 4k native context.
LLAMA_2_7B = ModelConfig(
    name="llama_2_7b",
    hidden_size=4096,
    intermediate_size=11008,
    vocab_size=32000,
    num_attention_heads=32,
    num_key_value_heads=32,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=4096,
)
# Llama 3 8B: 32 query / 8 KV heads, 128k vocabulary, 8k native context.
LLAMA_3_8B = ModelConfig(
    name="llama_3_8b",
    hidden_size=4096,
    intermediate_size=14336,
    vocab_size=128256,
    num_attention_heads=32,
    num_key_value_heads=8,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=8192,
)
# Lookup table used by get_benchmark_model_config; keys match ModelConfig.name.
MODEL_REGISTRY: Dict[str, ModelConfig] = {
    "llama_2_7b": LLAMA_2_7B,
    "llama_3_8b": LLAMA_3_8B,
}
# Profile used when no model name is supplied.
DEFAULT_MODEL_CONFIG = LLAMA_3_8B
def get_benchmark_model_config(model_name: Optional[str] = None) -> ModelConfig:
    """Resolve benchmark model config from name.

    Returns the canonical model architecture profile (hidden_size, vocab_size,
    dtype, etc.) for benchmark runs. Use this to obtain model attributes
    when building benchmark tensors and shapes.

    Args:
        model_name: Registry key (e.g. ``llama_2_7b``, ``llama_3_8b``).
            If None (or empty), returns ``DEFAULT_MODEL_CONFIG``.
    """
    if not model_name:
        return DEFAULT_MODEL_CONFIG
    return MODEL_REGISTRY[model_name]
def estimate_kernel_peak_memory(probe_fn: Callable[[], torch.Tensor]) -> int:
    """Run a forward + backward probe to measure peak memory (bytes).

    Call this with the *pure PyTorch* (e.g. huggingface) implementation --
    that typically has the highest memory footprint and therefore gives a
    safe upper-bound estimate. Returns the total peak bytes; divide by
    num_tokens if you need bytes-per-token for :func:`compute_seq_len_sweep_config`.

    The probe_fn performs setup and forward pass internally; cleanup is
    automatic, so callers do not need to manage tensor/layer lifecycle.

    Example::

        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
        kernel_bpt = peak_bytes // num_tokens  # if needed

    Args:
        probe_fn: Callable that performs setup, runs a forward pass, and
            returns an output tensor suitable for ``.backward()``.
    """
    # Resolve the per-backend torch module (e.g. torch.cuda) from the device
    # string; assumes the backend exposes the same memory-stats API — TODO
    # confirm for non-CUDA backends.
    device_str = infer_device()
    torch_device_mod = getattr(torch, device_str)
    # Start from a clean allocator state so the probe's peak is isolated from
    # whatever ran before.
    gc.collect()
    torch_device_mod.empty_cache()
    torch_device_mod.memory.reset_peak_memory_stats()
    y = probe_fn()
    # Backward with a random upstream gradient so activation-gradient memory
    # is included in the measured peak.
    y.backward(torch.randn_like(y))
    peak_bytes = torch_device_mod.max_memory_allocated()
    # Release the probe output before returning so later runs start clean.
    del y
    gc.collect()
    torch_device_mod.empty_cache()
    # Never return 0: callers divide by this value.
    return max(1, peak_bytes)
def compute_seq_len_sweep_config(
    model_cfg: ModelConfig,
    kernel_bytes_per_token: Optional[int] = None,
    memory_utilization: float = 0.4,
    max_seq_len: Optional[int] = None,
    max_batch_size: int = 32,
) -> SeqLenSweepConfig:
    """Compute safe batch_size and seq_len for sequence-length sweep (e.g. GEGLU).

    Peak memory is estimated as
    ``batch_size * seq_len * kernel_bytes_per_token`` and is capped at
    device memory * memory_utilization. Device memory is obtained
    internally via :func:`~liger_kernel.utils.get_total_gpu_memory`.

    Prefer obtaining *kernel_bytes_per_token* via
    :func:`estimate_kernel_peak_memory` (divide by num_tokens) rather
    than hardcoding an analytical estimate.

    Args:
        model_cfg: Model architecture config.
        kernel_bytes_per_token: Peak memory **per token** (``batch * seq_len``
            axis). Best obtained from :func:`estimate_kernel_peak_memory` / num_tokens.
            Falls back to a conservative heuristic
            (``hidden_size * dtype_bytes * 16``) when *None*.
        memory_utilization: Fraction of total device memory to target (0 to 1).
            Lower values are safer. Default ``0.4`` leaves headroom for
            framework overhead and CUDA/NPU context.
        max_seq_len: Hard upper bound for sequence length. Defaults to
            ``model_cfg.max_position_embeddings`` so the sweep never exceeds
            the model's native context window.
        max_batch_size: Hard upper bound for batch size.
    """
    total_memory_gb = get_total_gpu_memory()
    dtype_bytes = 2 if model_cfg.dtype in (torch.bfloat16, torch.float16) else 4
    if kernel_bytes_per_token is None:
        # Conservative heuristic when no probe measurement is available.
        kernel_bytes_per_token = model_cfg.hidden_size * dtype_bytes * 16
    if max_seq_len is None:
        max_seq_len = model_cfg.max_position_embeddings
    usable_bytes = total_memory_gb * (1024**3) * memory_utilization
    max_tokens = max(1, int(usable_bytes / kernel_bytes_per_token))
    seq_len = min(max_seq_len, max_tokens)
    if seq_len >= 1024:
        # Round down to a power of two for clean sweep values.
        seq_len = 2 ** int(math.log2(seq_len))
    else:
        # Floor at 1024 so the sweep stays meaningful, but never exceed the
        # model's native context window (the previous unconditional 1024 could
        # violate max_seq_len when it was smaller than 1024).
        seq_len = min(1024, max_seq_len)
    batch_size = max(1, min(max_tokens // seq_len, max_batch_size))
    return SeqLenSweepConfig(batch_size=batch_size, seq_len=seq_len)
def compute_hidden_size_sweep_config(
    model_cfg: ModelConfig,
    kernel_peak_bytes: int,
    bt: int = 4096,
    memory_utilization: float = 0.4,
    max_hidden_size_multiplier: int = 4,
) -> HiddenSizeSweepConfig:
    """Compute safe max_hidden_size for hidden_size sweep (e.g. DyT).

    For kernels with shape (BT, hidden_size) where BT is fixed and we sweep
    hidden_size. Uses probe peak memory to derive max_hidden_size.
    Device memory is obtained internally via :func:`~liger_kernel.utils.get_total_gpu_memory`.

    Args:
        model_cfg: Model config.
        kernel_peak_bytes: Peak memory from probe (BT, model.hidden_size).
        bt: Fixed BT dimension; must match the probe.
        memory_utilization: Fraction of device memory to use.
        max_hidden_size_multiplier: Cap max_hidden_size at model.hidden_size * this.
    """
    usable_bytes = get_total_gpu_memory() * (1024**3) * memory_utilization
    bytes_per_token = max(1, kernel_peak_bytes // bt)
    # Peak memory is extrapolated linearly in hidden_size from the probe run
    # (which used model_cfg.hidden_size at the same BT).
    memory_cap = int(usable_bytes * model_cfg.hidden_size / (bt * bytes_per_token))
    hard_cap = model_cfg.hidden_size * max_hidden_size_multiplier
    max_hidden_size = min(hard_cap, max(model_cfg.hidden_size, memory_cap))
    # Round down to a power of two, flooring at 1024.
    max_hidden_size = max(1024, 2 ** int(math.log2(max_hidden_size)))
    return HiddenSizeSweepConfig(bt=bt, max_hidden_size=max_hidden_size)
This diff is collapsed.
import os
import sys
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.utils import infer_device
device = infer_device()
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
#############################################################################
# Test the memory consumption of the fused linear ORPO loss
#############################################################################
def bench_memory_fused_linear_orpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    """Benchmark peak memory of a full fwd+bwd pass of the fused linear ORPO loss.

    Compares the Liger chunked implementation against a pure-PyTorch
    LM-head + ORPO loss reference from the test suite.

    Raises:
        ValueError: If the kernel provider is unrecognized.
    """
    # Imported lazily: the reference modules live in the test suite.
    from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO
    from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO

    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider

    # Instantiate once and retrieve the first output (the loss scalar) only
    torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
    liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
    torch_fwd = lambda x, target, nll_target: torch_lm_head_orpo(x, target, nll_target)[0]
    liger_fwd = lambda x, target, nll_target: liger_lm_head_orpo(x, target, nll_target)[0]

    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)
    nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)

    def fwd():
        if provider == "liger":
            return liger_fwd(_input, target, nll_target)
        elif provider == "huggingface":
            return torch_fwd(_input, target, nll_target)
        # Previously an unknown provider returned None and crashed on .backward().
        raise ValueError(f"Invalid provider: {provider} for fused_linear_orpo_loss")

    def full():
        y = fwd()
        y.backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )
# #############################################################################
# # Test the speed of the fused linear ORPO loss
# #############################################################################
def bench_speed_fused_linear_orpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    """Benchmark the latency of the fused linear ORPO loss.

    Compares the Liger chunked implementation against a pure-PyTorch
    LM-head + ORPO loss reference, in forward / backward / full modes.

    Raises:
        ValueError: If the kernel provider or operation mode is unrecognized.
    """
    # Imported lazily: the reference modules live in the test suite.
    from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO
    from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO

    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    # Instantiate once and retrieve the first output (the loss scalar) only
    torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
    liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
    torch_fwd = lambda x, target, nll_target: torch_lm_head_orpo(x, target, nll_target)[0]
    liger_fwd = lambda x, target, nll_target: liger_lm_head_orpo(x, target, nll_target)[0]

    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)
    nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)

    def fwd():
        if provider == "liger":
            return liger_fwd(_input, target, nll_target)
        elif provider == "huggingface":
            return torch_fwd(_input, target, nll_target)
        # Previously an unknown provider returned None and crashed later.
        raise ValueError(f"Invalid provider: {provider} for fused_linear_orpo_loss")

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        # Build the graph once, then time repeated backward passes only.
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )
    else:
        # Previously an unknown mode fell through to an UnboundLocalError.
        raise ValueError(f"Invalid mode: {mode} for fused_linear_orpo_loss")
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # One shared sweep definition drives both the speed and the memory runs.
    sweep = {
        "kernel_name": "fused_linear_orpo_loss",
        "x_name": "B",
        "x_label": "B",
        "x_values": [2**i for i in range(1, 5)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "T": 1024,
                "H": 4096,
                "V": 128256,
                "mode": "forward",
                "dtype": torch.bfloat16,
            }
        ],
        "overwrite": args.overwrite,
    }
    for bench_fn, op_modes, metric, unit in (
        (bench_speed_fused_linear_orpo_loss, ["forward", "full", "backward"], "speed", "ms"),
        (bench_memory_fused_linear_orpo_loss, ["full"], "memory", "MB"),
    ):
        run_benchmarks(
            bench_test_fn=bench_fn,
            kernel_operation_modes=op_modes,
            metric_name=metric,
            metric_unit=unit,
            **sweep,
        )
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment