import math
import torch
from benchmark_model_configs import compute_seq_len_sweep_config
from benchmark_model_configs import estimate_kernel_peak_memory
from benchmark_model_configs import get_benchmark_model_config
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from utils import run_memory_benchmark
from utils import run_speed_benchmark
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.utils import infer_device
device = infer_device()
def _setup_geglu(input: SingleBenchmarkRunInput):
"""Create input tensor and GEGLU layer from benchmark config."""
cfg = input.extra_benchmark_config
llama_config = LlamaConfig(
hidden_size=cfg["hidden_size"],
intermediate_size=cfg["intermediate_size"],
hidden_act=cfg["hidden_act"],
)
x = torch.randn(
cfg["bsz"],
input.x,
cfg["hidden_size"],
device=device,
dtype=cfg["dtype"],
requires_grad=True,
)
if input.kernel_provider == "liger":
layer = LigerGEGLUMLP(config=llama_config).to(device).to(cfg["dtype"])
elif input.kernel_provider == "huggingface":
layer = LlamaMLP(config=llama_config).to(device).to(cfg["dtype"])
else:
raise ValueError(f"Invalid provider: {input.kernel_provider} for GEGLU")
return x, layer
def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
x, layer = _setup_geglu(input)
return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x])
def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
x, layer = _setup_geglu(input)
return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode)
if __name__ == "__main__":
args = parse_benchmark_script_args()
model = get_benchmark_model_config(args.model)
probe_seq_len = 1024
def _probe():
probe_input = SingleBenchmarkRunInput(
x=probe_seq_len,
kernel_provider="huggingface",
extra_benchmark_config={
"bsz": 1,
"hidden_size": model.hidden_size,
"intermediate_size": model.intermediate_size,
"hidden_act": "gelu_pytorch_tanh",
"dtype": model.dtype,
},
)
x, layer = _setup_geglu(probe_input)
return layer(x)
peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
kernel_bpt = peak_bytes // probe_seq_len
config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
common_configs = {
"kernel_name": "geglu",
"x_name": "T",
"x_label": "sequence length",
"x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"bsz": config.batch_size,
"hidden_size": model.hidden_size,
"intermediate_size": model.intermediate_size,
"hidden_act": "gelu_pytorch_tanh",
"dtype": model.dtype,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_geglu,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_geglu,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
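# Example invocation (the script name is illustrative; --model and --overwrite
# mirror the args.model / args.overwrite fields consumed above):
#   python benchmark_geglu.py --model llama_3_8b --overwrite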
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.group_norm import LigerGroupNorm
from liger_kernel.utils import infer_device
device = infer_device()
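# Note: the "huggingface" provider in this script maps to plain torch.nn.GroupNorm;
# the comparison is Liger's Triton kernel vs. the eager PyTorch module.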
def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
C = input.x
provider = input.kernel_provider
mode = input.kernel_operation_mode
extra_benchmark_config = input.extra_benchmark_config
M = extra_benchmark_config["M"]
H = extra_benchmark_config["H"]
channels_per_group = extra_benchmark_config["channels_per_group"]
eps = extra_benchmark_config["eps"]
dtype = extra_benchmark_config["dtype"]
x_shape = (M, C, H)
    liger_gn = LigerGroupNorm(num_channels=C, num_groups=C // channels_per_group, eps=eps).to(device)
    torch_gn = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=eps).to(device)
    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)
    def y_fwd():
        if provider == "liger":
            return liger_gn(x)
        if provider == "huggingface":
            return torch_gn(x)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
elif mode == "backward":
y = y_fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(dy, retain_graph=True),
quantiles=QUANTILES,
grad_to_none=[x],
rep=500,
)
elif mode == "full":
def full():
y = y_fwd()
y.backward(dy, retain_graph=True)
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
C = input.x
provider = input.kernel_provider
extra_benchmark_config = input.extra_benchmark_config
M = extra_benchmark_config["M"]
H = extra_benchmark_config["H"]
channels_per_group = extra_benchmark_config["channels_per_group"]
eps = extra_benchmark_config["eps"]
dtype = extra_benchmark_config["dtype"]
x_shape = (M, C, H)
    liger_gn = LigerGroupNorm(num_channels=C, num_groups=C // channels_per_group, eps=eps).to(device)
    torch_gn = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=eps).to(device)
    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)
    def y_fwd():
        if provider == "liger":
            return liger_gn(x)
        if provider == "huggingface":
            return torch_gn(x)
def full():
y = y_fwd()
y.backward(dy, retain_graph=True)
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "group_norm",
"x_name": "C",
"x_label": "num_channels",
"x_values": [2**i for i in range(5, 12)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"M": 128,
"H": 512,
"channels_per_group": 4,
"dtype": torch.float32,
"eps": 1e-6,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_group_norm,
kernel_operation_modes=["forward", "full", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_group_norm,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import os
import sys
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.utils import infer_device
device = infer_device()
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
#############################################################################
# Test the memory consumption of the fused linear GRPO loss
#############################################################################
def bench_memory_fused_linear_grpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_grpo_loss import LigerLMHeadGRPO
from test.chunked_loss.test_grpo_loss import TorchLMHeadGRPO
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
importance_sampling_level = input.extra_benchmark_config["importance_sampling_level"]
provider = input.kernel_provider
# Instantiate once and retrieve the first output only
torch_lm_head_grpo = TorchLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to(
device
)
liger_lm_head_grpo = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to(
device
)
# Create inputs
_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
selected_token_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device)
attention_mask = torch.ones(B, T, device=device)
advantages = torch.randn(B, dtype=dtype, device=device)
ref_input = torch.randn(B, T, H, dtype=dtype, device=device)
torch_fwd = lambda: torch_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[
0
]
liger_fwd = lambda: liger_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[
0
]
def fwd():
if provider == "liger":
return liger_fwd()
elif provider == "torch":
return torch_fwd()
def full():
y = fwd()
y.backward()
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
#############################################################################
# Test the speed of the fused linear GRPO loss
#############################################################################
def bench_speed_fused_linear_grpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_grpo_loss import LigerLMHeadGRPO
from test.chunked_loss.test_grpo_loss import TorchLMHeadGRPO
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
importance_sampling_level = input.extra_benchmark_config["importance_sampling_level"]
provider = input.kernel_provider
mode = input.kernel_operation_mode
# Instantiate once and retrieve the first output only
torch_lm_head_grpo = TorchLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to(
device
)
liger_lm_head_grpo = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to(
device
)
# Create inputs
_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
selected_token_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device)
attention_mask = torch.ones(B, T, device=device)
advantages = torch.randn(B, dtype=dtype, device=device)
ref_input = torch.randn(B, T, H, dtype=dtype, device=device)
torch_fwd = lambda: torch_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[
0
]
liger_fwd = lambda: liger_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[
0
]
def fwd():
if provider == "liger":
return liger_fwd()
elif provider == "torch":
return torch_fwd()
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[_input],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd()
y.backward()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
# Benchmark token-level importance sampling (original GRPO)
token_configs = {
"kernel_name": "fused_linear_grpo_loss_token",
"x_name": "B",
"x_label": "B",
"x_values": [2**i for i in range(1, 5)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [
{
"T": 1024,
"H": 4096,
"V": 128256,
"importance_sampling_level": "token",
"dtype": torch.bfloat16,
}
],
"overwrite": args.overwrite,
}
# Benchmark sequence-level importance sampling (GSPO)
sequence_configs = {
"kernel_name": "fused_linear_grpo_loss_sequence",
"x_name": "B",
"x_label": "B",
"x_values": [2**i for i in range(1, 5)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [
{
"T": 1024,
"H": 4096,
"V": 128256,
"importance_sampling_level": "sequence",
"dtype": torch.bfloat16,
}
],
"overwrite": args.overwrite,
}
# Run benchmarks for token-level (GRPO)
print("Benchmarking GRPO (token-level importance sampling)...")
run_benchmarks(
bench_test_fn=bench_speed_fused_linear_grpo_loss,
kernel_operation_modes=["forward", "full", "backward"],
metric_name="speed",
metric_unit="ms",
**token_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_fused_linear_grpo_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**token_configs,
)
# Run benchmarks for sequence-level (GSPO)
print("Benchmarking GSPO (sequence-level importance sampling)...")
run_benchmarks(
bench_test_fn=bench_speed_fused_linear_grpo_loss,
kernel_operation_modes=["forward", "full", "backward"],
metric_name="speed",
metric_unit="ms",
**sequence_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_fused_linear_grpo_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**sequence_configs,
)
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.jsd import LigerJSD
from liger_kernel.utils import get_total_gpu_memory
from liger_kernel.utils import infer_device
device = infer_device()
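# TorchJSD below implements the beta-weighted (generalized) Jensen-Shannon
# divergence. With P = exp(log_p) (target) and Q = exp(log_q) (input):
#   M = beta * P + (1 - beta) * Q
#   JSD_beta(P || Q) = beta * KL(P || M) + (1 - beta) * KL(Q || M)
# torch.lerp(exp(log_q), exp(log_p), beta) computes M, and KLDivLoss with
# log_target=True supplies the pointwise KL terms.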
class TorchJSD(torch.nn.Module):
def __init__(
self,
beta: float = 0.5,
ignore_index: int = -100,
dtype: torch.dtype = torch.float,
):
        super().__init__()
self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True)
self.beta = beta
self.ignore_index = ignore_index
self.dtype = dtype
def forward(
self,
log_q: torch.Tensor, # input
log_p: torch.Tensor, # target
label=None,
):
log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (1 - self.beta) * self.kl(
torch.log(m), log_q
).sum(dim=-1)
if label is not None:
loss = torch.where(label != self.ignore_index, loss, 0.0)
n_non_ignore = (label != self.ignore_index).sum().item()
if n_non_ignore == 0:
                loss = torch.zeros((), dtype=loss.dtype, device=loss.device)  # stay a tensor so .to(self.dtype) below works
else:
loss = (loss / n_non_ignore).sum()
else:
loss = (loss / log_q.shape[0]).sum()
return loss.to(self.dtype)
def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
V = input.x
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
torch_jsd = TorchJSD()
liger_jsd = LigerJSD()
_input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1)
target = torch.randn(B * T, V, device=device).log_softmax(dim=-1)
def fwd():
if input.kernel_provider == "liger":
return liger_jsd(_input, target)
else:
return torch_jsd(_input, target)
if input.kernel_operation_mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
elif input.kernel_operation_mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
quantiles=QUANTILES,
grad_to_none=[_input],
rep=100,
)
elif input.kernel_operation_mode == "full":
def full():
y = fwd()
y.backward(retain_graph=True)
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
torch_jsd = TorchJSD()
liger_jsd = LigerJSD()
V = input.x
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
_input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1)
target = torch.randn(B * T, V, device=device).log_softmax(dim=-1)
def fwd():
if input.kernel_provider == "liger":
return liger_jsd(_input, target)
else:
return torch_jsd(_input, target)
def full():
y = fwd()
y.backward(retain_graph=True)
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
gpu_memory_gbs = get_total_gpu_memory()
    # The full test needs ~54 GB at vocab size 2**17 with the torch implementation
if gpu_memory_gbs >= 54:
x_max = 17
else:
x_max = 16
common_args = {
"kernel_name": "jsd",
"x_name": "V",
"x_label": "vocab size",
"x_values": [2**i for i in range(12, x_max + 1)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [{"B": 4, "T": 2048}],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_memory_jsd,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_args,
)
run_benchmarks(
bench_test_fn=bench_speed_jsd,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_args,
)
import torch
import torch.nn as nn
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.kl_div import LigerKLDIVLoss
from liger_kernel.utils import infer_device
device = infer_device()
# Vocab-size sweep bounds: x_values run from 2**S up to 2**(E - 1).
S, E = 12, 18
def bench_speed_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
reduction = "batchmean"
V = input.x
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
torch_kl_div = nn.KLDivLoss(reduction=reduction)
liger_kl_div = LigerKLDIVLoss(reduction=reduction)
_input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1)
target = torch.randn(B * T, V, device=device).softmax(dim=-1)
def fwd():
if input.kernel_provider == "liger":
return liger_kl_div(_input, target)
else:
return torch_kl_div(_input, target)
if input.kernel_operation_mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
elif input.kernel_operation_mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
quantiles=QUANTILES,
grad_to_none=[_input],
rep=100,
)
elif input.kernel_operation_mode == "full":
def full():
y = fwd()
y.backward(retain_graph=True)
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
reduction = "batchmean"
torch_kl_div = nn.KLDivLoss(reduction=reduction)
liger_kl_div = LigerKLDIVLoss(reduction=reduction)
V = input.x
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
_input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1)
target = torch.randn(B * T, V, device=device).softmax(dim=-1)
def fwd():
if input.kernel_provider == "liger":
return liger_kl_div(_input, target)
else:
return torch_kl_div(_input, target)
def full():
y = fwd()
y.backward(retain_graph=True)
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_args = {
"kernel_name": "kl_div",
"x_name": "V",
"x_label": "vocab size",
"x_values": [2**i for i in range(12, 18)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [{"B": 8, "T": 512}],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_memory_kldiv,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_args,
)
run_benchmarks(
bench_test_fn=bench_speed_kldiv,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_args,
)
import os
import sys
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss
from liger_kernel.utils import infer_device
device = infer_device()
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
class TorchLMHeadKTO(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
use_bias: bool = False,
use_ref_bias: bool = False,
ignore_index: int = -100,
beta: float = 0.1,
):
from test.chunked_loss.test_kto_loss import HFKTOLoss
super().__init__()
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_bias, dtype=dtype)
self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_ref_bias, dtype=dtype)
self.KTO_loss = HFKTOLoss(
ignore_index=ignore_index,
beta=beta,
use_ref_model=True,
).get_batch_loss_metrics
def forward(self, x, ref_x, y, preference_labels, kl=None):
return self.KTO_loss(
weight=self.lin.weight,
_input=x,
target=y,
bias=self.lin.bias,
ref_input=ref_x,
ref_weight=self.ref_lin.weight,
ref_bias=self.ref_lin.bias,
preference_labels=preference_labels,
kl=kl,
)
class LigerLMHeadKTO(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
use_bias: bool = False,
use_ref_bias: bool = False,
ignore_index: int = -100,
beta: float = 0.1,
):
super().__init__()
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_bias, dtype=dtype)
self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_ref_bias, dtype=dtype)
self.KTO_loss = LigerFusedLinearKTOLoss(
ignore_index=ignore_index,
beta=beta,
use_ref_model=True,
)
def forward(self, x, ref_x, y, preference_labels, kl=None):
return self.KTO_loss(
_input=x,
lin_weight=self.lin.weight,
target=y,
preference_labels=preference_labels,
bias=self.lin.bias,
ref_input=ref_x,
ref_weight=self.ref_lin.weight,
ref_bias=self.ref_lin.bias,
kl=kl,
)
def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
beta = input.extra_benchmark_config["beta"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider
torch_kto_loss = TorchLMHeadKTO(
H=H,
V=V,
dtype=dtype,
use_bias=bias,
use_ref_bias=bias,
ignore_index=ignore_index,
beta=beta,
).to(device)
liger_kto_loss = LigerLMHeadKTO(
H=H,
V=V,
dtype=dtype,
use_bias=bias,
use_ref_bias=bias,
ignore_index=ignore_index,
beta=beta,
).to(device)
# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)
# Target shape: [B, T]
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
# Preference labels shape: [B]
# Create binary preference labels (0 or 1) for each sequence in the batch
# Used to indicate preferred sequences (1) vs non-preferred sequences (0)
preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device)
# Precomputed KL divergence between policy and reference distributions
kl = torch.randn(1, device=device, dtype=dtype)
# Add ignore_index tokens to simulate padding
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index
# Add ref_x with the same shape as _input
ref_input = torch.randn(B, T, H, device=device, dtype=dtype)
def fwd():
if provider == "liger":
return liger_kto_loss(
x=_input,
ref_x=ref_input,
y=target,
preference_labels=preference_labels,
kl=kl,
)[0]
elif provider == "huggingface":
return torch_kto_loss(
x=_input,
ref_x=ref_input,
y=target,
preference_labels=preference_labels,
kl=kl,
)[0]
def full():
y = fwd()
y.backward()
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
beta = input.extra_benchmark_config["beta"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider
mode = input.kernel_operation_mode
torch_kto_loss = TorchLMHeadKTO(
H=H,
V=V,
dtype=dtype,
beta=beta,
ignore_index=ignore_index,
use_bias=bias,
).to(device)
liger_kto_loss = LigerLMHeadKTO(
H=H,
V=V,
dtype=dtype,
beta=beta,
ignore_index=ignore_index,
use_bias=bias,
).to(device)
# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)
# Target shape: [B, T]
target = torch.randint(V, (B, T), device=device, dtype=torch.long)
# Preference labels shape: [B]
# Create binary preference labels (0 or 1) for each sequence in the batch
# Used to indicate preferred sequences (1) vs non-preferred sequences (0)
preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device)
# Precomputed KL divergence between policy and reference distributions
kl = torch.randn(1, device=device, dtype=dtype)
# Add ignore_index tokens
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index
# Add ref_x with the same shape as _input
ref_input = torch.randn(B, T, H, device=device, dtype=dtype)
def fwd():
if provider == "liger":
return liger_kto_loss(
x=_input,
ref_x=ref_input,
y=target,
preference_labels=preference_labels,
kl=kl,
)[0]
elif provider == "huggingface":
return torch_kto_loss(
x=_input,
ref_x=ref_input,
y=target,
preference_labels=preference_labels,
kl=kl,
)[0]
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[_input],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd()
y.backward()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "kto_loss",
"x_name": "B",
"x_label": "Batch Size (B)",
"x_values": [2**i for i in range(1, 6)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"T": 512,
"H": 1024,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
"bias": True,
"beta": 0.1,
"ignore_index": 42,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_kto_loss,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_kto_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.utils import infer_device
device = infer_device()
def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
N = input.x
provider = input.kernel_provider
mode = input.kernel_operation_mode
extra_benchmark_config = input.extra_benchmark_config
M = extra_benchmark_config["M"]
eps = extra_benchmark_config["eps"]
dtype = extra_benchmark_config["dtype"]
x_shape = (M, N)
triton_ln = LigerLayerNorm(hidden_size=N).to(device)
torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)
def y_fwd():
if provider == "liger":
return triton_ln(x)
if provider == "huggingface":
return torch_ln(x)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
elif mode == "backward":
y = y_fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(dy, retain_graph=True),
quantiles=QUANTILES,
grad_to_none=[x],
rep=500,
)
elif mode == "full":
def full():
y = y_fwd()
y.backward(dy, retain_graph=True)
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
N = input.x
provider = input.kernel_provider
dtype = input.extra_benchmark_config["dtype"]
M = input.extra_benchmark_config["M"]
eps = input.extra_benchmark_config["eps"]
x_shape = (M, N)
triton_ln = LigerLayerNorm(hidden_size=N).to(device)
torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)
def y_fwd():
if provider == "liger":
return triton_ln(x)
if provider == "huggingface":
return torch_ln(x)
def full():
y = y_fwd()
y.backward(dy, retain_graph=True)
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "layer_norm",
"x_name": "N",
"x_label": "hidden size",
"x_values": [2**i for i in range(10, 15)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [{"M": 4096, "dtype": torch.float32, "eps": 1e-6}],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_layer_norm,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_layer_norm,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import torch
import triton
from transformers.models.llama4.configuration_llama4 import Llama4TextConfig
from transformers.models.llama4.modeling_llama4 import Llama4TextRotaryEmbedding
from transformers.models.llama4.modeling_llama4 import apply_rotary_emb
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb
from liger_kernel.utils import infer_device
from liger_kernel.utils import transformers_version_dispatch
device = infer_device()
def bench_speed_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
provider = input.kernel_provider
mode = input.kernel_operation_mode
extra_benchmark_config = input.extra_benchmark_config
num_q_heads = extra_benchmark_config["num_q_heads"]
num_kv_heads = extra_benchmark_config["num_kv_heads"]
dtype = extra_benchmark_config["dtype"]
# x can be either hidden_size or seq_len
    hidden_size = extra_benchmark_config.get("hidden_size", input.x)
    seq_len = extra_benchmark_config.get("seq_len", input.x)
head_dim = hidden_size // num_q_heads
# Create Llama4TextConfig for the rotary embedding
config = Llama4TextConfig(
hidden_size=hidden_size,
num_attention_heads=num_q_heads,
num_key_value_heads=num_kv_heads,
head_dim=head_dim,
max_position_embeddings=seq_len,
)
    # Note: both version branches of the dispatch below are identical.
    rotary_emb = transformers_version_dispatch(
"4.48.0",
Llama4TextRotaryEmbedding,
Llama4TextRotaryEmbedding,
before_kwargs={"config": config, "device": device},
after_kwargs={"config": config, "device": device},
)
q = torch.randn(
(1, seq_len, num_q_heads, head_dim),
device=device,
requires_grad=True,
dtype=dtype,
)
k = torch.randn(
(1, seq_len, num_kv_heads, head_dim),
device=device,
requires_grad=True,
dtype=dtype,
)
    # randn_like inherits device and dtype from q / k, so no extra kwargs are needed.
    dq, dk = torch.randn_like(q), torch.randn_like(k)
pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
freqs_cis = rotary_emb(q, pos_ids)
def fwd():
if provider == "liger":
return liger_llama4_text_rotary_pos_emb(q, k, freqs_cis)
elif provider == "huggingface":
return apply_rotary_emb(q, k, freqs_cis)
else:
raise ValueError(f"Invalid provider: {provider} for Llama4 RoPE embedding")
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
grad_to_none=[q, k],
rep=400,
quantiles=QUANTILES,
)
elif mode == "backward":
q_out, k_out = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True),
grad_to_none=[q, k],
rep=400,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
q_out, k_out = fwd()
torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
grad_to_none=[q, k],
rep=400,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
provider = input.kernel_provider
extra_benchmark_config = input.extra_benchmark_config
num_q_heads = extra_benchmark_config["num_q_heads"]
num_kv_heads = extra_benchmark_config["num_kv_heads"]
dtype = extra_benchmark_config["dtype"]
# x can be either hidden_size or seq_len
    hidden_size = extra_benchmark_config.get("hidden_size", input.x)
    seq_len = extra_benchmark_config.get("seq_len", input.x)
head_dim = hidden_size // num_q_heads
# Create Llama4TextConfig for the rotary embedding
config = Llama4TextConfig(
hidden_size=hidden_size,
num_attention_heads=num_q_heads,
num_key_value_heads=num_kv_heads,
head_dim=head_dim,
max_position_embeddings=seq_len,
)
rotary_emb = transformers_version_dispatch(
"4.48.0",
Llama4TextRotaryEmbedding,
Llama4TextRotaryEmbedding,
before_kwargs={"config": config, "device": device},
after_kwargs={"config": config, "device": device},
)
q = torch.randn(
(1, seq_len, num_q_heads, head_dim),
device=device,
requires_grad=True,
dtype=dtype,
)
k = torch.randn(
(1, seq_len, num_kv_heads, head_dim),
device=device,
requires_grad=True,
dtype=dtype,
)
    # randn_like inherits device and dtype from q / k, so no extra kwargs are needed.
    dq, dk = torch.randn_like(q), torch.randn_like(k)
pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
freqs_cis = rotary_emb(q, pos_ids)
def full():
if provider == "liger":
q_out, k_out = liger_llama4_text_rotary_pos_emb(q, k, freqs_cis)
else:
q_out, k_out = apply_rotary_emb(q, k, freqs_cis)
torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True)
mem_50, mem_20, mem_80 = _test_memory(
full,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs_varying_hidden_size = {
"kernel_name": "llama4_rope",
"x_name": "H",
"x_label": "hidden size",
"x_values": [32 * (2**i) for i in range(4, 10, 2)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"dtype": torch.bfloat16,
"seq_len": 2048,
"num_q_heads": 32,
"num_kv_heads": 8,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_llama4_rope,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs_varying_hidden_size,
)
run_benchmarks(
bench_test_fn=bench_memory_llama4_rope,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs_varying_hidden_size,
)
common_configs_varying_seq_len = {
"kernel_name": "llama4_rope",
"x_name": "T",
"x_label": "sequence length",
"x_values": [2**i for i in range(10, 15)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"dtype": torch.bfloat16,
"hidden_size": 8192,
"num_q_heads": 32,
"num_kv_heads": 8,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_llama4_rope,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs_varying_seq_len,
)
run_benchmarks(
bench_test_fn=bench_memory_llama4_rope,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs_varying_seq_len,
)
import os
import sys
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.functional import liger_mhc_coeffs
from liger_kernel.transformers.functional import liger_mhc_post_res
from liger_kernel.transformers.functional import liger_mhc_pre
from liger_kernel.utils import infer_device
device = infer_device()
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
def bench_speed_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
from test.transformers.test_mhc import mhc_coeffs_ref
T = input.x
B = input.extra_benchmark_config["B"]
HC = input.extra_benchmark_config["HC"]
C = input.extra_benchmark_config["C"]
sub_kernel = input.extra_benchmark_config["sub_kernel"]
tmax = input.extra_benchmark_config["tmax"]
rms_eps = input.extra_benchmark_config["rms_eps"]
pre_eps = input.extra_benchmark_config["pre_eps"]
sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"]
post_mult = input.extra_benchmark_config["post_mult"]
provider = input.kernel_provider
mode = input.kernel_operation_mode
coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult)
need_grad = mode in ("backward", "full")
x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=need_grad)
K, M = HC * C, HC * HC + 2 * HC
phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(need_grad)
b_param = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=need_grad)
alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad)
alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad)
alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad)
grad_to_none = [x, phi, b_param, alpha_pre, alpha_post, alpha_res] if need_grad else None
if sub_kernel == "coeffs":
def fwd():
if provider == "liger":
return liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
return mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
def fwd_loss():
h_pre, h_post, h_res = fwd()
return h_pre.square().mean() + h_post.square().mean() + h_res.square().mean()
elif sub_kernel == "pre":
with torch.no_grad():
h_pre_c, _, _ = liger_mhc_coeffs(
x.detach(),
phi.detach(),
b_param.detach(),
alpha_pre.detach(),
alpha_post.detach(),
alpha_res.detach(),
**coeffs_cfg,
)
h_pre_c.requires_grad_(need_grad)
grad_to_none = [x, h_pre_c] if need_grad else None
def fwd():
if provider == "liger":
return liger_mhc_pre(x, h_pre_c)
return (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2)
def fwd_loss():
return fwd().square().mean()
elif sub_kernel == "post_res":
with torch.no_grad():
_, h_post_c, h_res_c = liger_mhc_coeffs(
x.detach(),
phi.detach(),
b_param.detach(),
alpha_pre.detach(),
alpha_post.detach(),
alpha_res.detach(),
**coeffs_cfg,
)
h_post_c.requires_grad_(need_grad)
h_res_c.requires_grad_(need_grad)
f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16, requires_grad=need_grad)
grad_to_none = [x, f_out, h_post_c, h_res_c] if need_grad else None
def fwd():
if provider == "liger":
return liger_mhc_post_res(x, f_out, h_post_c, h_res_c)
return torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze(
-1
) * f_out.float().unsqueeze(-2)
def fwd_loss():
return fwd().square().mean()
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
elif mode == "backward":
y = fwd_loss()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=grad_to_none,
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd_loss()
y.backward()
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=grad_to_none, rep=100, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
def bench_memory_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
from test.transformers.test_mhc import mhc_coeffs_ref
T = input.x
B = input.extra_benchmark_config["B"]
HC = input.extra_benchmark_config["HC"]
C = input.extra_benchmark_config["C"]
sub_kernel = input.extra_benchmark_config["sub_kernel"]
tmax = input.extra_benchmark_config["tmax"]
rms_eps = input.extra_benchmark_config["rms_eps"]
pre_eps = input.extra_benchmark_config["pre_eps"]
sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"]
post_mult = input.extra_benchmark_config["post_mult"]
provider = input.kernel_provider
coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult)
x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=True)
K, M = HC * C, HC * HC + 2 * HC
phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(True)
b_param = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=True)
alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)
alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)
alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)
if sub_kernel == "coeffs":
def full():
if provider == "liger":
hp, hpo, hr = liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
else:
hp, hpo, hr = mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
(hp.square().mean() + hpo.square().mean() + hr.square().mean()).backward()
elif sub_kernel == "pre":
with torch.no_grad():
h_pre_c, _, _ = liger_mhc_coeffs(
x.detach(),
phi.detach(),
b_param.detach(),
alpha_pre.detach(),
alpha_post.detach(),
alpha_res.detach(),
**coeffs_cfg,
)
h_pre_c.requires_grad_(True)
def full():
if provider == "liger":
out = liger_mhc_pre(x, h_pre_c)
else:
out = (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2)
out.square().mean().backward()
elif sub_kernel == "post_res":
with torch.no_grad():
_, h_post_c, h_res_c = liger_mhc_coeffs(
x.detach(),
phi.detach(),
b_param.detach(),
alpha_pre.detach(),
alpha_post.detach(),
alpha_res.detach(),
**coeffs_cfg,
)
h_post_c.requires_grad_(True)
h_res_c.requires_grad_(True)
f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16, requires_grad=True)
def full():
if provider == "liger":
out = liger_mhc_post_res(x, f_out, h_post_c, h_res_c)
else:
out = torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze(
-1
) * f_out.float().unsqueeze(-2)
out.square().mean().backward()
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
if __name__ == "__main__":
args = parse_benchmark_script_args()
for sub_kernel in ["coeffs", "pre", "post_res"]:
common_configs = {
"kernel_name": f"mhc_{sub_kernel}",
"x_name": "T",
"x_label": "Sequence Length (T)",
"x_values": [2**i for i in range(7, 12)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [
{
"B": 4,
"HC": 4,
"C": 4096,
"tmax": 20,
"rms_eps": 1e-6,
"pre_eps": 0.0,
"sinkhorn_eps": 1e-6,
"post_mult": 2.0,
"sub_kernel": sub_kernel,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_mhc,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_mhc,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.mhc import LigerMHC
from liger_kernel.utils import infer_device
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
device = infer_device()
class RMSNorm(nn.Module):
def __init__(self, hidden_size: int, *, eps: float, dtype: torch.dtype, device: str):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device))
def forward(self, x: torch.Tensor) -> torch.Tensor:
var = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(var + self.eps)
return x * self.weight
def _build_rope_cache(seq_len: int, head_dim: int, *, device: torch.device, dtype: torch.dtype):
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim))
positions = torch.arange(seq_len, device=device, dtype=torch.float32)
freqs = torch.einsum("i,j->ij", positions, inv_freq)
cos = freqs.cos().to(dtype)
sin = freqs.sin().to(dtype)
return cos, sin
def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
cos = cos[None, None, :, :]
sin = sin[None, None, :, :]
return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
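# Note: _apply_rope rotates even/odd feature pairs but re-packs the two rotated
# components as concatenated halves. The layout differs from the interleaved
# convention, but q and k share it, so attention scores are unaffected.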
class MiniLlamaAttention(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, *, dtype: torch.dtype, device: str):
super().__init__()
assert hidden_size % num_heads == 0
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
assert self.head_dim % 2 == 0, "head_dim must be even for RoPE"
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device)
self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device)
self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device)
self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
bsz, seq_len, _ = x.shape
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
q = q.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
cos, sin = _build_rope_cache(seq_len, self.head_dim, device=x.device, dtype=q.dtype)
q = _apply_rope(q, cos, sin)
k = _apply_rope(k, cos, sin)
attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
attn = attn.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size)
return self.o_proj(attn)
class MiniLlamaMLP(nn.Module):
def __init__(self, hidden_size: int, intermediate_mult: int, *, dtype: torch.dtype, device: str):
super().__init__()
intermediate_size = hidden_size * intermediate_mult
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
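# MiniLlamaMLP is the standard SwiGLU block: down_proj(SiLU(gate_proj(x)) * up_proj(x)).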
class AttentionBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, *, dtype: torch.dtype, device: str):
super().__init__()
self.norm = RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.attn = MiniLlamaAttention(hidden_size, num_heads, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.attn(self.norm(x))
class MLPBlock(nn.Module):
def __init__(self, hidden_size: int, intermediate_mult: int, *, dtype: torch.dtype, device: str):
super().__init__()
self.norm = RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.mlp = MiniLlamaMLP(hidden_size, intermediate_mult, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(self.norm(x))
class TorchMHC(nn.Module):
def __init__(
self,
layer: nn.Module,
*,
hc: int,
c: int,
tmax: int,
rms_eps: float,
pre_eps: float,
sinkhorn_eps: float,
post_mult: float,
phi_dtype: torch.dtype,
):
super().__init__()
self.layer = layer
self.hc = int(hc)
self.c = int(c)
self.tmax = int(tmax)
self.rms_eps = float(rms_eps)
self.pre_eps = float(pre_eps)
self.sinkhorn_eps = float(sinkhorn_eps)
self.post_mult = float(post_mult)
layer_param = next(layer.parameters())
device = layer_param.device
m = hc * hc + 2 * hc
k = hc * c
self.phi = nn.Parameter(torch.randn(k, m, dtype=phi_dtype, device=device) * 0.02)
self.b = nn.Parameter(torch.zeros(m, dtype=torch.float32, device=device))
self.alpha_pre = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
self.alpha_post = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
self.alpha_res = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
self.layer_dtype = layer_param.dtype
def _coeffs(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from test.transformers.test_mhc import mhc_coeffs_ref
return mhc_coeffs_ref(
x,
self.phi,
self.b,
self.alpha_pre,
self.alpha_post,
self.alpha_res,
tmax=self.tmax,
rms_eps=self.rms_eps,
pre_eps=self.pre_eps,
sinkhorn_eps=self.sinkhorn_eps,
post_mult=self.post_mult,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h_pre, h_post, h_res = self._coeffs(x)
x_in = (x.float() * h_pre.unsqueeze(-1)).sum(dim=-2)
if x_in.dtype != self.layer_dtype:
x_in = x_in.to(self.layer_dtype)
f_out = self.layer(x_in)
x_out = torch.einsum("...oi,...ic->...oc", h_res, x.float()) + h_post.unsqueeze(-1) * f_out.float().unsqueeze(
-2
)
return x_out.to(x.dtype)
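# TorchMHC reference semantics (matching the forward above): h_pre collapses the
# HC streams into a single layer input, the wrapped layer runs once, and
# h_res / h_post re-expand the result across streams:
#   x_in[..., c]     = sum_i h_pre[..., i] * x[..., i, c]
#   x_out[..., o, c] = sum_i h_res[..., o, i] * x[..., i, c] + h_post[..., o] * f_out[..., c]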
class MHCDecoderLayer(nn.Module):
def __init__(
self,
mhc_cls: type[nn.Module],
*,
hidden_size: int,
hc: int,
num_heads: int,
intermediate_mult: int,
tmax: int,
dtype: torch.dtype,
device: str,
):
super().__init__()
attn = AttentionBlock(hidden_size, num_heads, dtype=dtype, device=device)
mlp = MLPBlock(hidden_size, intermediate_mult, dtype=dtype, device=device)
self.attn = mhc_cls(
attn,
hc=hc,
c=hidden_size,
tmax=tmax,
rms_eps=1e-6,
pre_eps=1e-4,
sinkhorn_eps=1e-6,
post_mult=2.0,
phi_dtype=dtype,
)
self.mlp = mhc_cls(
mlp,
hc=hc,
c=hidden_size,
tmax=tmax,
rms_eps=1e-6,
pre_eps=1e-4,
sinkhorn_eps=1e-6,
post_mult=2.0,
phi_dtype=dtype,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.attn(x)
x = self.mlp(x)
return x
class BenchMiniMHCLM(nn.Module):
def __init__(
self,
mhc_cls: type[nn.Module],
*,
vocab_size: int,
hidden_size: int,
hc: int,
num_layers: int,
num_heads: int,
intermediate_mult: int,
tmax: int,
dtype: torch.dtype,
device: str,
):
super().__init__()
self.hc = hc
self.hidden_size = hidden_size
self.embed = nn.Embedding(vocab_size, hc * hidden_size, dtype=dtype, device=device)
self.layers = nn.ModuleList(
[
MHCDecoderLayer(
mhc_cls,
hidden_size=hidden_size,
hc=hc,
num_heads=num_heads,
intermediate_mult=intermediate_mult,
tmax=tmax,
dtype=dtype,
device=device,
)
for _ in range(num_layers)
]
)
self.final_norm = RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False, dtype=dtype, device=device)
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
x = self.embed(input_ids)
bsz, seq_len, _ = x.shape
x = x.view(bsz, seq_len, self.hc, self.hidden_size)
for layer in self.layers:
x = layer(x)
x = x.mean(dim=-2)
x = self.final_norm(x)
return self.lm_head(x)
def _build_model(
provider: str,
*,
hidden_size: int,
hc: int,
num_layers: int,
num_heads: int,
intermediate_mult: int,
vocab_size: int,
tmax: int,
dtype: torch.dtype,
):
mhc_cls = LigerMHC if provider == "liger" else TorchMHC
return BenchMiniMHCLM(
mhc_cls,
vocab_size=vocab_size,
hidden_size=hidden_size,
hc=hc,
num_layers=num_layers,
num_heads=num_heads,
intermediate_mult=intermediate_mult,
tmax=tmax,
dtype=dtype,
device=device,
)
def bench_speed_mhc_lm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
hidden_size = int(input.x)
provider = input.kernel_provider
mode = input.kernel_operation_mode
extra = input.extra_benchmark_config
bsz = extra["B"]
seq_len = extra["T"]
hc = extra["HC"]
num_layers = extra["layers"]
num_heads = extra["heads"]
vocab_size = extra["vocab"]
dtype = extra["dtype"]
tmax = extra["tmax"]
intermediate_mult = extra["intermediate_mult"]
if hidden_size % num_heads != 0:
raise ValueError("hidden_size must be divisible by num_heads")
model = _build_model(
provider,
hidden_size=hidden_size,
hc=hc,
num_layers=num_layers,
num_heads=num_heads,
intermediate_mult=intermediate_mult,
vocab_size=vocab_size,
tmax=tmax,
dtype=dtype,
)
input_ids = torch.randint(0, vocab_size, (bsz, seq_len), device=device)
def fwd():
return model(input_ids)
def fwd_loss():
return fwd().float().mean()
grad_to_none = list(model.parameters())
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, grad_to_none=grad_to_none, rep=100)
elif mode == "backward":
loss = fwd_loss()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: loss.backward(retain_graph=True),
quantiles=QUANTILES,
grad_to_none=grad_to_none,
rep=100,
)
elif mode == "full":
def full():
loss = fwd_loss()
loss.backward()
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=grad_to_none, rep=100)
else:
raise ValueError(f"Unknown mode: {mode}")
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_mhc_lm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
hidden_size = int(input.x)
provider = input.kernel_provider
extra = input.extra_benchmark_config
bsz = extra["B"]
seq_len = extra["T"]
hc = extra["HC"]
num_layers = extra["layers"]
num_heads = extra["heads"]
vocab_size = extra["vocab"]
dtype = extra["dtype"]
tmax = extra["tmax"]
intermediate_mult = extra["intermediate_mult"]
if hidden_size % num_heads != 0:
raise ValueError("hidden_size must be divisible by num_heads")
model = _build_model(
provider,
hidden_size=hidden_size,
hc=hc,
num_layers=num_layers,
num_heads=num_heads,
intermediate_mult=intermediate_mult,
vocab_size=vocab_size,
tmax=tmax,
dtype=dtype,
)
input_ids = torch.randint(0, vocab_size, (bsz, seq_len), device=device)
def fwd():
return model(input_ids)
def full():
loss = fwd().float().mean()
loss.backward()
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "mhc_llama_like_lm",
"x_name": "hidden_size",
"x_label": "hidden_size",
"x_values": [256, 512, 1024],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [
{
"B": 2,
"T": 256,
"HC": 4,
"layers": 2,
"heads": 8,
"vocab": 4096,
"dtype": torch.bfloat16,
"tmax": 8,
"intermediate_mult": 4,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_mhc_lm,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_mhc_lm,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
"""
Standardized benchmark model configurations.
Provides canonical model architecture profiles and device-specific benchmark
parameters. All benchmark scripts should derive their tensor shapes from these
shared configs rather than defining ad-hoc per-script constants.
Usage::
from benchmark_model_configs import (
get_benchmark_model_config,
compute_seq_len_sweep_config,
estimate_kernel_peak_memory,
)
args = parse_benchmark_script_args()
model = get_benchmark_model_config(args.model)
# Measure actual memory via a small probe, then compute sweep config
peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
bpt = peak_bytes // probe_num_tokens
config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=bpt)
"""
import gc
import math
from dataclasses import dataclass
from typing import Callable
from typing import Dict
from typing import Optional
import torch
from liger_kernel.utils import get_total_gpu_memory
from liger_kernel.utils import infer_device
@dataclass(frozen=True)
class ModelConfig:
"""Canonical model architecture profile.
Each field corresponds to a standard LLM hyperparameter. Benchmark scripts
pick the fields they need (e.g. hidden_size for RMSNorm, vocab_size for
CrossEntropy) while kernel-specific overrides (e.g. hidden_act for GEGLU)
are applied locally in the benchmark script.
"""
name: str
hidden_size: int
intermediate_size: int
vocab_size: int
num_attention_heads: int
num_key_value_heads: int
head_dim: int
hidden_act: str
max_position_embeddings: int = 8192
rms_norm_eps: float = 1e-5
dtype: torch.dtype = torch.bfloat16
@dataclass(frozen=True)
class SeqLenSweepConfig:
"""Config for benchmarks that sweep sequence length (e.g. GEGLU, SwiGLU).
Attributes:
batch_size: Safe batch size for the sweep.
seq_len: Max sequence length (upper bound for x_values).
"""
batch_size: int
seq_len: int
@dataclass(frozen=True)
class HiddenSizeSweepConfig:
"""Config for benchmarks that sweep hidden_size with fixed BT (e.g. DyT).
Attributes:
bt: Fixed batch * seq dimension.
max_hidden_size: Upper bound for hidden_size sweep.
"""
bt: int
max_hidden_size: int
# ── Model Profiles ──────────────────────────────────────────────────────────
LLAMA_2_7B = ModelConfig(
name="llama_2_7b",
hidden_size=4096,
intermediate_size=11008,
vocab_size=32000,
num_attention_heads=32,
num_key_value_heads=32,
head_dim=128,
hidden_act="silu",
max_position_embeddings=4096,
)
LLAMA_3_8B = ModelConfig(
name="llama_3_8b",
hidden_size=4096,
intermediate_size=14336,
vocab_size=128256,
num_attention_heads=32,
num_key_value_heads=8,
head_dim=128,
hidden_act="silu",
max_position_embeddings=8192,
)
MODEL_REGISTRY: Dict[str, ModelConfig] = {
"llama_2_7b": LLAMA_2_7B,
"llama_3_8b": LLAMA_3_8B,
}
DEFAULT_MODEL_CONFIG = LLAMA_3_8B
def get_benchmark_model_config(model_name: Optional[str] = None) -> ModelConfig:
"""Resolve benchmark model config from name.
Returns the canonical model architecture profile (hidden_size, vocab_size,
dtype, etc.) for benchmark runs. Use this to obtain model attributes
when building benchmark tensors and shapes.
Args:
model_name: Registry key (e.g. ``llama_2_7b``, ``llama_3_8b``).
If None, returns ``DEFAULT_MODEL_CONFIG``.
"""
return MODEL_REGISTRY[model_name] if model_name else DEFAULT_MODEL_CONFIG
def estimate_kernel_peak_memory(probe_fn: Callable[[], torch.Tensor]) -> int:
"""Run a forward + backward probe to measure peak memory (bytes).
    Call this with the *pure PyTorch* (e.g. huggingface) implementation, which
    typically has the highest memory footprint and therefore gives a
    safe upper-bound estimate. Returns the total peak bytes; divide by
num_tokens if you need bytes-per-token for :func:`compute_seq_len_sweep_config`.
The probe_fn performs setup and forward pass internally; cleanup is
automatic, so callers do not need to manage tensor/layer lifecycle.
Example::
peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
kernel_bpt = peak_bytes // num_tokens # if needed
Args:
probe_fn: Callable that performs setup, runs a forward pass, and
returns an output tensor suitable for ``.backward()``.
"""
device_str = infer_device()
torch_device_mod = getattr(torch, device_str)
gc.collect()
torch_device_mod.empty_cache()
torch_device_mod.memory.reset_peak_memory_stats()
y = probe_fn()
y.backward(torch.randn_like(y))
peak_bytes = torch_device_mod.max_memory_allocated()
del y
gc.collect()
torch_device_mod.empty_cache()
return max(1, peak_bytes)
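# Illustrative only: a minimal probe sketch for ``estimate_kernel_peak_memory``,
# assuming a plain ``nn.Linear`` stands in for the kernel under test. The layer,
# shapes, and names below are hypothetical and not part of this module's API.
#
#     import torch.nn as nn
#
#     def _probe():
#         dev = infer_device()
#         layer = nn.Linear(4096, 4096, device=dev, dtype=torch.bfloat16)
#         x = torch.randn(1, 1024, 4096, device=dev, dtype=torch.bfloat16, requires_grad=True)
#         return layer(x)
#
#     peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)  # forward + backward peak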
def compute_seq_len_sweep_config(
model_cfg: ModelConfig,
kernel_bytes_per_token: Optional[int] = None,
memory_utilization: float = 0.4,
max_seq_len: Optional[int] = None,
max_batch_size: int = 32,
) -> SeqLenSweepConfig:
"""Compute safe batch_size and seq_len for sequence-length sweep (e.g. GEGLU).
Peak memory is estimated as
``batch_size * seq_len * kernel_bytes_per_token`` and is capped at
device memory * memory_utilization. Device memory is obtained
internally via :func:`~liger_kernel.utils.get_total_gpu_memory`.
Prefer obtaining *kernel_bytes_per_token* via
:func:`estimate_kernel_peak_memory` (divide by num_tokens) rather
than hardcoding an analytical estimate.
Args:
model_cfg: Model architecture config.
kernel_bytes_per_token: Peak memory **per token** (``batch * seq_len``
axis). Best obtained from :func:`estimate_kernel_peak_memory` / num_tokens.
Falls back to a conservative heuristic
(``hidden_size * dtype_bytes * 16``) when *None*.
memory_utilization: Fraction of total device memory to target (0 to 1).
Lower values are safer. Default ``0.4`` leaves headroom for
framework overhead and CUDA/NPU context.
max_seq_len: Hard upper bound for sequence length. Defaults to
``model_cfg.max_position_embeddings`` so the sweep never exceeds
the model's native context window.
max_batch_size: Hard upper bound for batch size.
"""
total_memory_gb = get_total_gpu_memory()
dtype_bytes = 2 if model_cfg.dtype in (torch.bfloat16, torch.float16) else 4
if kernel_bytes_per_token is None:
kernel_bytes_per_token = model_cfg.hidden_size * dtype_bytes * 16
if max_seq_len is None:
max_seq_len = model_cfg.max_position_embeddings
usable_bytes = total_memory_gb * (1024**3) * memory_utilization
max_tokens = max(1, int(usable_bytes / kernel_bytes_per_token))
seq_len = min(max_seq_len, max_tokens)
    # Round down to a power of two so sweep x_values stay aligned; floor at 1024.
    seq_len = 2 ** int(math.log2(seq_len)) if seq_len >= 1024 else 1024
batch_size = max(1, min(max_tokens // seq_len, max_batch_size))
return SeqLenSweepConfig(batch_size=batch_size, seq_len=seq_len)
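# Illustrative only: a worked example of the capacity math above, assuming an
# 80 GB device, the default 0.4 utilization, and a measured 2 MiB per token
# for LLAMA_3_8B (max_position_embeddings=8192):
#
#     usable_bytes = 80 * 2**30 * 0.4             # ~34.4e9 bytes budget
#     max_tokens   = usable_bytes // (2 * 2**20)  # = 16384 tokens
#     seq_len      = min(8192, 16384)             # = 8192 (already a power of two)
#     batch_size   = min(16384 // 8192, 32)       # = 2
#
# i.e. SeqLenSweepConfig(batch_size=2, seq_len=8192).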
def compute_hidden_size_sweep_config(
model_cfg: ModelConfig,
kernel_peak_bytes: int,
bt: int = 4096,
memory_utilization: float = 0.4,
max_hidden_size_multiplier: int = 4,
) -> HiddenSizeSweepConfig:
"""Compute safe max_hidden_size for hidden_size sweep (e.g. DyT).
For kernels with shape (BT, hidden_size) where BT is fixed and we sweep
hidden_size. Uses probe peak memory to derive max_hidden_size.
Device memory is obtained internally via :func:`~liger_kernel.utils.get_total_gpu_memory`.
Args:
model_cfg: Model config.
kernel_peak_bytes: Peak memory from probe (BT, model.hidden_size).
bt: Fixed BT dimension; must match the probe.
memory_utilization: Fraction of device memory to use.
max_hidden_size_multiplier: Cap max_hidden_size at model.hidden_size * this.
"""
total_memory_gb = get_total_gpu_memory()
usable_bytes = total_memory_gb * (1024**3) * memory_utilization
kernel_bpt = max(1, kernel_peak_bytes // bt)
max_hidden_size = min(
model_cfg.hidden_size * max_hidden_size_multiplier,
max(
model_cfg.hidden_size,
int(usable_bytes * model_cfg.hidden_size / (bt * kernel_bpt)),
),
)
max_hidden_size = max(1024, 2 ** int(math.log2(max_hidden_size)))
return HiddenSizeSweepConfig(bt=bt, max_hidden_size=max_hidden_size)
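# Illustrative only: deriving a hidden-size sweep (DyT-style) from a probe run
# at shape (bt, model.hidden_size). ``probe_peak_bytes`` is a hypothetical
# value obtained from ``estimate_kernel_peak_memory``.
#
#     model = get_benchmark_model_config()
#     cfg = compute_hidden_size_sweep_config(model, kernel_peak_bytes=probe_peak_bytes, bt=4096)
#     x_values = [2**i for i in range(10, int(math.log2(cfg.max_hidden_size)) + 1)]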
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention
from liger_kernel.utils import infer_device
device = infer_device()
class TorchMultiTokenAttention(torch.nn.Module):
def __init__(self, C_in, C_out, K, groups, bias, dtype, device):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(C_out, C_in // groups, K, K, dtype=dtype, device=device))
self.bias = torch.nn.Parameter(torch.empty(C_out, dtype=dtype, device=device)) if bias else None
self.K = K
self.groups = groups
    def forward(self, scores):
        B, C_in, L, _ = scores.shape
        # Causal mask: position i may only attend to positions <= i.
        mask = torch.tril(torch.ones(L, L, dtype=torch.bool, device=scores.device)).view(1, 1, L, L)
        inf = torch.tensor(-1e9, device=scores.device, dtype=scores.dtype)
        zero = torch.tensor(0.0, device=scores.device, dtype=scores.dtype)
        # Masked softmax over the key dimension.
        s_inf = scores.masked_fill(~mask, inf)
        probs = torch.nn.functional.softmax(s_inf, dim=-1)
        # Mix attention probabilities across neighboring tokens with a 2D conv,
        # then re-apply the causal mask to the mixed map.
        out_c = torch.nn.functional.conv2d(
            probs, self.weight, self.bias, stride=1, padding=self.K // 2, groups=self.groups
        )
        return out_c.masked_fill(~mask, zero)
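# Illustrative only: a quick parity check between this reference module and
# LigerMultiTokenAttention, a sketch assuming small fp32 shapes; the tolerances
# are arbitrary and bfloat16 would need looser ones.
#
#     ref = TorchMultiTokenAttention(C_in=4, C_out=4, K=3, groups=1, bias=True,
#                                    dtype=torch.float32, device=device)
#     lig = LigerMultiTokenAttention(in_channels=4, out_channels=4, kernel_size=3,
#                                    stride=1, padding=1, dilation=1, groups=1,
#                                    bias=True).to(device)
#     with torch.no_grad():
#         ref.weight.copy_(lig.weight)
#         ref.bias.copy_(lig.bias)
#     scores = torch.randn(2, 4, 32, 32, device=device)
#     torch.testing.assert_close(ref(scores), lig(scores), rtol=1e-4, atol=1e-4)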
def bench_speed_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
L = input.x
provider = input.kernel_provider
mode = input.kernel_operation_mode
extra_benchmark_config = input.extra_benchmark_config
B = extra_benchmark_config["B"]
C_in = extra_benchmark_config["C_in"]
C_out = extra_benchmark_config["C_out"]
K = extra_benchmark_config["K"]
groups = extra_benchmark_config["groups"]
bias = extra_benchmark_config["bias"]
dtype = extra_benchmark_config["dtype"]
x_shape = (B, C_in, L, L)
triton_attn = (
LigerMultiTokenAttention(
in_channels=C_in,
out_channels=C_out,
kernel_size=K,
stride=1,
padding=K // 2,
dilation=1,
groups=groups,
bias=bias,
)
.to(device)
.to(dtype)
)
torch_attn = TorchMultiTokenAttention(
C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device
)
with torch.no_grad():
torch_attn.weight.copy_(triton_attn.weight)
if bias:
torch_attn.bias.copy_(triton_attn.bias)
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)
def fwd():
if provider == "liger":
return triton_attn(x)
elif provider == "torch":
return torch_attn(x)
print(f"Starting Warmup for input size: {x_shape}")
    y = fwd()
    if mode in ("backward", "full"):
        y.backward(dy, retain_graph=True)
print("Done Warmup")
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, grad_to_none=[x], rep=100, quantiles=QUANTILES)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(dy, retain_graph=True),
grad_to_none=[x],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd()
y.backward(dy, retain_graph=True)
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
L = input.x
provider = input.kernel_provider
extra_benchmark_config = input.extra_benchmark_config
B = extra_benchmark_config["B"]
C_in = extra_benchmark_config["C_in"]
C_out = extra_benchmark_config["C_out"]
K = extra_benchmark_config["K"]
groups = extra_benchmark_config["groups"]
bias = extra_benchmark_config["bias"]
dtype = extra_benchmark_config["dtype"]
x_shape = (B, C_in, L, L)
triton_attn = (
LigerMultiTokenAttention(
in_channels=C_in,
out_channels=C_out,
kernel_size=K,
stride=1,
padding=K // 2,
dilation=1,
groups=groups,
bias=bias,
)
.to(device)
.to(dtype)
)
torch_attn = TorchMultiTokenAttention(
C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device
)
with torch.no_grad():
torch_attn.weight.copy_(triton_attn.weight)
if bias:
torch_attn.bias.copy_(triton_attn.bias)
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)
def fwd():
if provider == "liger":
return triton_attn(x)
elif provider == "torch":
return torch_attn(x)
def full():
y = fwd()
y.backward(dy, retain_graph=True)
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "multi_token_attention",
"x_name": "L",
"x_label": "sequence length",
"x_values": [2**i for i in range(5, 10)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [
{
"B": 2,
"C_in": 4,
"C_out": 4,
"K": 3,
"groups": 1,
"bias": True,
"dtype": torch.bfloat16,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_multi_token_attention,
kernel_operation_modes=["forward", "full", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_multi_token_attention,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import os
import sys
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.utils import infer_device
device = infer_device()
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
#############################################################################
# Test the memory consumption of the fused linear ORPO loss
#############################################################################
def bench_memory_fused_linear_orpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO
from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
# Instantiate once and retrieve the first output only
torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
torch_fwd = lambda x, target, nll_target: torch_lm_head_orpo(x, target, nll_target)[0]
liger_fwd = lambda x, target, nll_target: liger_lm_head_orpo(x, target, nll_target)[0]
_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)
def fwd():
if provider == "liger":
return liger_fwd(_input, target, nll_target)
elif provider == "huggingface":
return torch_fwd(_input, target, nll_target)
def full():
y = fwd()
y.backward()
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
#############################################################################
# Test the speed of the fused linear ORPO loss
#############################################################################
def bench_speed_fused_linear_orpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO
from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
mode = input.kernel_operation_mode
# Instantiate once and retrieve the first output only
torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
torch_fwd = lambda x, target, nll_target: torch_lm_head_orpo(x, target, nll_target)[0]
liger_fwd = lambda x, target, nll_target: liger_lm_head_orpo(x, target, nll_target)[0]
_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)
def fwd():
if provider == "liger":
return liger_fwd(_input, target, nll_target)
elif provider == "huggingface":
return torch_fwd(_input, target, nll_target)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[_input],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd()
y.backward()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "fused_linear_orpo_loss",
"x_name": "B",
"x_label": "B",
"x_values": [2**i for i in range(1, 5)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"T": 1024,
"H": 4096,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_fused_linear_orpo_loss,
kernel_operation_modes=["forward", "full", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_fused_linear_orpo_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import torch
import torch.nn as nn
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.poly_norm import LigerPolyNorm
from liger_kernel.utils import infer_device
device = infer_device()
class NaivePolyNorm(nn.Module):
"""
Naive PyTorch implementation of PolyNorm.
Reference:
https://github.com/BryceZhuo/PolyCom/
PolyNorm formula:
y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
where norm(u) = u / sqrt(mean(u²) + ε)
"""
def __init__(self, eps=1e-6):
super().__init__()
# Align with PolyCom reference: (1/3, 1/3, 1/3) and bias=1.0
self.weight = nn.Parameter(torch.full((3,), 1.0 / 3.0))
self.bias = nn.Parameter(torch.tensor(1.0))
self.variance_epsilon = eps
def _norm(self, x):
"""RMSNorm operation"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
def forward(self, hidden_states):
"""
Forward pass of PolyNorm
Args:
hidden_states: input tensor of shape (..., H)
Returns:
output tensor of same shape as input
"""
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
# Compute powers
x_pow3 = hidden_states**3
x_pow2 = hidden_states**2
x_pow1 = hidden_states**1
# Normalize each power
norm_x3 = self._norm(x_pow3)
norm_x2 = self._norm(x_pow2)
norm_x1 = self._norm(x_pow1)
# Weighted sum with bias
output = self.weight[0] * norm_x3 + self.weight[1] * norm_x2 + self.weight[2] * norm_x1 + self.bias
return output.to(input_dtype)
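# Illustrative only: a sanity check of the PolyNorm formula above. For a
# constant positive row, each norm(x**p) is ~1 elementwise, so with the
# default weights (1/3, 1/3, 1/3) and bias 1.0 the output is ~2 everywhere.
#
#     m = NaivePolyNorm()
#     x = torch.full((1, 8), 2.0)
#     print(m(x))  # ~2.0 in every position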
def bench_speed_poly_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
N = input.x
provider = input.kernel_provider
mode = input.kernel_operation_mode
extra_benchmark_config = input.extra_benchmark_config
M = extra_benchmark_config["M"]
eps = extra_benchmark_config["eps"]
dtype = extra_benchmark_config["dtype"]
x_shape = (M, N)
triton_poly = LigerPolyNorm(eps=eps).to(device)
naive_poly = NaivePolyNorm(eps=eps).to(device)
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)
# utility functions
def y_fwd():
if provider == "liger":
return triton_poly(x)
if provider == "huggingface":
return naive_poly(x)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
y_fwd,
grad_to_none=[x],
rep=500,
quantiles=QUANTILES,
)
elif mode == "backward":
y = y_fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(dy, retain_graph=True),
grad_to_none=[x],
rep=500,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = y_fwd()
y.backward(dy, retain_graph=True)
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
grad_to_none=[x],
rep=500,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_poly_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
N = input.x
provider = input.kernel_provider
extra_benchmark_config = input.extra_benchmark_config
M = extra_benchmark_config["M"]
eps = extra_benchmark_config["eps"]
dtype = extra_benchmark_config["dtype"]
x_shape = (M, N)
triton_poly = LigerPolyNorm(eps=eps).to(device)
naive_poly = NaivePolyNorm(eps=eps).to(device)
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)
# utility functions
def y_fwd():
if provider == "liger":
return triton_poly(x)
if provider == "huggingface":
return naive_poly(x)
def full():
y = y_fwd()
y.backward(dy, retain_graph=True)
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "poly_norm",
"x_name": "H",
"x_label": "hidden size",
"x_values": [2**i for i in range(10, 16)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [{"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6}],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_poly_norm,
kernel_operation_modes=["forward", "full", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_poly_norm,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import torch
import triton
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLTextConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding
from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
from liger_kernel.utils import infer_device
device = infer_device()
def bench_speed_qwen2vl_mrope(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
provider = input.kernel_provider
mode = input.kernel_operation_mode
extra_benchmark_config = input.extra_benchmark_config
num_q_heads = extra_benchmark_config["num_q_heads"]
num_kv_heads = extra_benchmark_config["num_kv_heads"]
dtype = extra_benchmark_config["dtype"]
# x can be either hidden_size or seq_len
hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
head_dim = hidden_size // num_q_heads
    # Split the rotary half-dim into (temporal, height, width) sections; the
    # three sections must sum to head_dim // 2.
    mrope_section_hw = head_dim * 3 // 16
    mrope_section = [
        head_dim // 2 - 2 * mrope_section_hw,
        mrope_section_hw,
        mrope_section_hw,
    ]
config = Qwen2VLTextConfig(
hidden_size=hidden_size,
num_attention_heads=num_q_heads,
num_key_value_heads=num_kv_heads,
rope_theta=1000000.0,
mrope_section=mrope_section,
)
rotary_emb = Qwen2VLRotaryEmbedding(config, device=device)
q = torch.randn(
(1, seq_len, num_q_heads, head_dim),
device=device,
requires_grad=True,
dtype=dtype,
).transpose(1, 2)
k = torch.randn(
(1, seq_len, num_kv_heads, head_dim),
device=device,
requires_grad=True,
dtype=dtype,
).transpose(1, 2)
dq, dk = (
torch.randn_like(q, device=device, dtype=dtype),
torch.randn_like(k, device=device, dtype=dtype),
)
pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
cos, sin = rotary_emb(k, pos_ids)
def fwd():
if provider == "liger":
return liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
elif provider == "huggingface":
return apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
else:
raise ValueError(f"Invalid provider: {provider} for M-RoPE embedding")
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
grad_to_none=[q, k],
rep=400,
quantiles=QUANTILES,
)
elif mode == "backward":
q_out, k_out = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True),
grad_to_none=[q, k],
rep=400,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
q_out, k_out = fwd()
torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
grad_to_none=[q, k],
rep=400,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_qwen2vl_mrope(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
provider = input.kernel_provider
extra_benchmark_config = input.extra_benchmark_config
num_q_heads = extra_benchmark_config["num_q_heads"]
num_kv_heads = extra_benchmark_config["num_kv_heads"]
dtype = extra_benchmark_config["dtype"]
# x can be either hidden_size or seq_len
hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
head_dim = hidden_size // num_q_heads
    # Split the rotary half-dim into (temporal, height, width) sections; the
    # three sections must sum to head_dim // 2.
    mrope_section_hw = head_dim * 3 // 16
    mrope_section = [
        head_dim // 2 - 2 * mrope_section_hw,
        mrope_section_hw,
        mrope_section_hw,
    ]
config = Qwen2VLTextConfig(
hidden_size=hidden_size,
num_attention_heads=num_q_heads,
num_key_value_heads=num_kv_heads,
rope_theta=1000000.0,
mrope_section=mrope_section,
)
rotary_emb = Qwen2VLRotaryEmbedding(config, device=device)
q = torch.randn(
(1, seq_len, num_q_heads, head_dim),
device=device,
requires_grad=True,
dtype=dtype,
).transpose(1, 2)
k = torch.randn(
(1, seq_len, num_kv_heads, head_dim),
device=device,
requires_grad=True,
dtype=dtype,
).transpose(1, 2)
dq, dk = (
torch.randn_like(q, device=device, dtype=dtype),
torch.randn_like(k, device=device, dtype=dtype),
)
pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
cos, sin = rotary_emb(k, pos_ids)
def full():
if provider == "liger":
q_out, k_out = liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
else:
q_out, k_out = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True)
mem_50, mem_20, mem_80 = _test_memory(
full,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs_varying_hidden_size = {
"kernel_name": "qwen2vl_mrope",
"x_name": "H",
"x_label": "hidden size",
"x_values": [32 * (2**i) for i in range(4, 10, 2)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"dtype": torch.bfloat16,
"seq_len": 2048,
"num_q_heads": 32,
"num_kv_heads": 8,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_qwen2vl_mrope,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs_varying_hidden_size,
)
run_benchmarks(
bench_test_fn=bench_memory_qwen2vl_mrope,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs_varying_hidden_size,
)
common_configs_varying_seq_len = {
"kernel_name": "qwen2vl_mrope",
"x_name": "T",
"x_label": "sequence length",
"x_values": [2**i for i in range(10, 15)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"dtype": torch.bfloat16,
"hidden_size": 8192,
"num_q_heads": 32,
"num_kv_heads": 8,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_qwen2vl_mrope,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs_varying_seq_len,
)
run_benchmarks(
bench_test_fn=bench_memory_qwen2vl_mrope,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs_varying_seq_len,
)
import torch
import torch.nn as nn
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.utils import infer_device
device = infer_device()
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
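# Illustrative only: a one-line check of the RMSNorm formula above; with the
# default all-ones weight the module output equals the bare normalization.
# Shapes and tolerances are arbitrary.
#
#     rms = LlamaRMSNorm(hidden_size=64).to(device)
#     x = torch.randn(2, 64, device=device)
#     manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
#     torch.testing.assert_close(rms(x), manual, rtol=1e-5, atol=1e-5)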
def bench_speed_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
N = input.x
provider = input.kernel_provider
mode = input.kernel_operation_mode
extra_benchmark_config = input.extra_benchmark_config
M = extra_benchmark_config["M"]
eps = extra_benchmark_config["eps"]
dtype = extra_benchmark_config["dtype"]
x_shape = (M, N)
triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device)
llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to(device)
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)
# utility functions
def y_fwd():
if provider == "liger":
return triton_rms(x)
if provider == "huggingface":
return llama_rms(x)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
y_fwd,
grad_to_none=[x],
rep=500,
quantiles=QUANTILES,
)
elif mode == "backward":
y = y_fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(dy, retain_graph=True),
grad_to_none=[x],
rep=500,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = y_fwd()
y.backward(dy, retain_graph=True)
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
grad_to_none=[x],
rep=500,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
N = input.x
provider = input.kernel_provider
extra_benchmark_config = input.extra_benchmark_config
M = extra_benchmark_config["M"]
eps = extra_benchmark_config["eps"]
dtype = extra_benchmark_config["dtype"]
x_shape = (M, N)
triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device)
llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to(device)
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)
# utility functions
def y_fwd():
if provider == "liger":
return triton_rms(x)
if provider == "huggingface":
return llama_rms(x)
def full():
y = y_fwd()
y.backward(dy, retain_graph=True)
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "rms_norm",
"x_name": "H",
"x_label": "hidden size",
"x_values": [2**i for i in range(10, 16)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [{"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6}],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_rms_norm,
kernel_operation_modes=["forward", "full", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_rms_norm,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import torch
import triton
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.utils import infer_device
from liger_kernel.utils import transformers_version_dispatch
device = infer_device()
def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
provider = input.kernel_provider
mode = input.kernel_operation_mode
extra_benchmark_config = input.extra_benchmark_config
num_q_heads = extra_benchmark_config["num_q_heads"]
num_kv_heads = extra_benchmark_config["num_kv_heads"]
dtype = extra_benchmark_config["dtype"]
# x can be either hidden_size or seq_len
hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
head_dim = hidden_size // num_q_heads
rotary_emb = transformers_version_dispatch(
"4.48.0",
LlamaRotaryEmbedding,
LlamaRotaryEmbedding,
before_kwargs={"dim": head_dim, "device": device},
        after_kwargs={"config": LlamaConfig(num_key_value_heads=num_kv_heads, head_dim=head_dim), "device": device},
)
q = torch.randn(
(1, seq_len, num_q_heads, head_dim),
device=device,
requires_grad=True,
dtype=dtype,
).transpose(1, 2)
k = torch.randn(
(1, seq_len, num_kv_heads, head_dim),
device=device,
requires_grad=True,
dtype=dtype,
).transpose(1, 2)
dq, dk = (
torch.randn_like(q, device=device, dtype=dtype),
        torch.randn_like(k, device=device, dtype=dtype),
)
pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
cos, sin = rotary_emb(k, pos_ids)
def fwd():
if provider == "liger":
return liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
elif provider == "huggingface":
return apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
else:
raise ValueError(f"Invalid provider: {provider} for RoPE embedding")
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
grad_to_none=[q, k],
rep=400,
quantiles=QUANTILES,
)
elif mode == "backward":
q_out, k_out = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True),
grad_to_none=[q, k],
rep=400,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
q_out, k_out = fwd()
torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
grad_to_none=[q, k],
rep=400,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
provider = input.kernel_provider
extra_benchmark_config = input.extra_benchmark_config
num_q_heads = extra_benchmark_config["num_q_heads"]
num_kv_heads = extra_benchmark_config["num_kv_heads"]
dtype = extra_benchmark_config["dtype"]
# x can be either hidden_size or seq_len
hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
head_dim = hidden_size // num_q_heads
rotary_emb = transformers_version_dispatch(
"4.48.0",
LlamaRotaryEmbedding,
LlamaRotaryEmbedding,
before_kwargs={"dim": head_dim, "device": device},
        after_kwargs={"config": LlamaConfig(num_key_value_heads=num_kv_heads, head_dim=head_dim), "device": device},
)
q = torch.randn(
(1, seq_len, num_q_heads, head_dim),
device=device,
requires_grad=True,
dtype=dtype,
).transpose(1, 2)
k = torch.randn(
(1, seq_len, num_kv_heads, head_dim),
device=device,
requires_grad=True,
dtype=dtype,
).transpose(1, 2)
dq, dk = (
torch.randn_like(q, device=device, dtype=dtype),
        torch.randn_like(k, device=device, dtype=dtype),
)
pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
cos, sin = rotary_emb(k, pos_ids)
def full():
if provider == "liger":
q_out, k_out = liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
else:
q_out, k_out = apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True)
mem_50, mem_20, mem_80 = _test_memory(
full,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs_varying_hidden_size = {
"kernel_name": "rope",
"x_name": "H",
"x_label": "hidden size",
"x_values": [32 * (2**i) for i in range(4, 10, 2)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"dtype": torch.bfloat16,
"seq_len": 2048,
"num_q_heads": 32,
"num_kv_heads": 8,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_rope,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs_varying_hidden_size,
)
run_benchmarks(
bench_test_fn=bench_memory_rope,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs_varying_hidden_size,
)
common_configs_varying_seq_len = {
"kernel_name": "rope",
"x_name": "T",
"x_label": "sequence length",
"x_values": [2**i for i in range(10, 15)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"dtype": torch.bfloat16,
"hidden_size": 8192,
"num_q_heads": 32,
"num_kv_heads": 8,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_rope,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs_varying_seq_len,
)
run_benchmarks(
bench_test_fn=bench_memory_rope,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs_varying_seq_len,
)
import os
import sys
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.utils import infer_device
device = infer_device()
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
#############################################################################
# Test the memory consumption of the fused linear SimPO loss
#############################################################################
def bench_memory_fused_linear_simpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO
from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
# Instantiate once and retrieve the first output only
torch_lm_head_simpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
torch_fwd = lambda x, target: torch_lm_head_simpo(x, target)[0]
liger_fwd = lambda x, target: liger_lm_head_simpo(x, target)[0]
_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
def fwd():
if provider == "liger":
return liger_fwd(_input, target)
elif provider == "huggingface":
return torch_fwd(_input, target)
def full():
y = fwd()
y.backward()
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
#############################################################################
# Test the speed of the fused linear SimPO loss
#############################################################################
def bench_speed_fused_linear_simpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO
from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
mode = input.kernel_operation_mode
# Instantiate once and retrieve the first output only
torch_lm_head_simpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
torch_fwd = lambda x, target: torch_lm_head_simpo(x, target)[0]
liger_fwd = lambda x, target: liger_lm_head_simpo(x, target)[0]
_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
def fwd():
if provider == "liger":
return liger_fwd(_input, target)
elif provider == "huggingface":
return torch_fwd(_input, target)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[_input],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd()
y.backward()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "fused_linear_simpo_loss",
"x_name": "B",
"x_label": "B",
"x_values": [2**i for i in range(1, 5)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"T": 1024,
"H": 4096,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_fused_linear_simpo_loss,
kernel_operation_modes=["forward", "full", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_fused_linear_simpo_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.softmax import LigerSoftmax
from liger_kernel.utils import infer_device
device = infer_device()
def bench_speed_softmax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
N = input.x
provider = input.kernel_provider
mode = input.kernel_operation_mode
extra_benchmark_config = input.extra_benchmark_config
M = extra_benchmark_config["M"]
dtype = extra_benchmark_config["dtype"]
x_shape = (M, N)
liger_softmax = LigerSoftmax().to(device).to(dtype)
torch_softmax = torch.nn.Softmax(dim=-1).to(device).to(dtype)
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)
def y_fwd():
if provider == "liger":
return liger_softmax(x)
if provider == "torch":
return torch_softmax(x)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
elif mode == "backward":
y = y_fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(dy, retain_graph=True),
quantiles=QUANTILES,
grad_to_none=[x],
rep=500,
)
elif mode == "full":
def full():
y = y_fwd()
y.backward(dy, retain_graph=True)
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)
if any(val is None for val in (ms_20, ms_50, ms_80)):
raise RuntimeError(f"Benchmark speed result is None: ms_20={ms_20}, ms_50={ms_50}, ms_80={ms_80}")
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_softmax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    N = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    dtype = extra_benchmark_config.get("dtype", torch.float32)
    x_shape = (M, N)
    torch_softmax = torch.nn.Softmax(dim=-1)
    liger_softmax = LigerSoftmax().to(device).to(dtype)
    # Match the speed benchmark: a 2-D (M, N) input, softmax over the last dim.
    x = torch.randn(x_shape, device=device, dtype=dtype, requires_grad=True)
def fwd():
if provider == "liger":
return liger_softmax(x)
elif provider == "torch":
return torch_softmax(x)
else:
raise ValueError(f"Invalid provider: {provider} for softmax")
def full():
y = fwd()
y.backward(torch.ones_like(y), retain_graph=True)
if mode == "forward":
mem_50, mem_20, mem_80 = _test_memory(fwd, quantiles=QUANTILES)
elif mode == "backward":
do = torch.ones_like(x)
y = fwd()
mem_50, mem_20, mem_80 = _test_memory(lambda: y.backward(do, retain_graph=True), quantiles=QUANTILES)
else:
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
if any(val is None for val in (mem_20, mem_50, mem_80)):
raise RuntimeError(f"Benchmark memory result is None: mem_20={mem_20}, mem_50={mem_50}, mem_80={mem_80}")
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = dict(
kernel_name="softmax",
x_name="N",
x_label="hidden size",
x_values=[128, 256, 512, 1024, 2048, 4096],
kernel_providers=["liger", "torch"],
extra_benchmark_configs=[
{"M": 2048, "dtype": torch.float32},
{"M": 2048, "dtype": torch.bfloat16},
],
)
run_benchmarks(
bench_test_fn=bench_speed_softmax,
kernel_operation_modes=["forward", "full", "backward"],
metric_name="speed",
metric_unit="ms",
overwrite=args.overwrite,
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_softmax,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
overwrite=args.overwrite,
**common_configs,
)
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention
from liger_kernel.utils import infer_device
device = infer_device()
class TorchSparseMultiTokenAttention(torch.nn.Module):
def __init__(self, C_in, C_out, K, groups, bias, dtype, device):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(C_out, C_in // groups, K, K, dtype=dtype, device=device))
self.bias = torch.nn.Parameter(torch.empty(C_out, dtype=dtype, device=device)) if bias else None
self.K = K
self.groups = groups
self.dtype = dtype
self.compute_dtype = torch.float32
def forward(self, scores):
B, C_in, L, _ = scores.shape
mask = torch.tril(torch.ones(L, L, dtype=torch.bool, device=scores.device)).view(1, 1, L, L)
inf = torch.tensor(-1e9, device=scores.device, dtype=self.compute_dtype)
zero = torch.tensor(0.0, device=scores.device, dtype=self.compute_dtype)
s_compute = scores.to(self.compute_dtype)
s_inf = s_compute.masked_fill(~mask, inf)
        # Sparsemax (Martins & Astudillo, 2016): Euclidean projection of the
        # scores onto the probability simplex, producing sparse attention rows.
        dim = -1
        z = s_inf
        # Sort scores descending and take cumulative sums.
        z_sorted, _ = torch.sort(z, dim=dim, descending=True)
        cum_sum = torch.cumsum(z_sorted, dim=dim)
        k_indices = torch.arange(1, L + 1, device=z.device, dtype=z.dtype).view(1, 1, 1, L)
        # Exclude masked (-1e9) entries from the support.
        is_positive = z_sorted > -1e8
        # Support size: the largest k with 1 + k * z_(k) > sum of the top-k scores.
        condition = (1 + k_indices * z_sorted > cum_sum) & is_positive
        k_sparsemax = torch.sum(condition, dim=dim, keepdim=True)
        k_sparsemax_safe = torch.max(k_sparsemax, torch.ones_like(k_sparsemax))
        # Threshold tau chosen so the clamped outputs sum to 1 over the support.
        cum_sum_k = torch.gather(cum_sum, dim=dim, index=k_sparsemax_safe.long() - 1)
        tau = (cum_sum_k - 1) / k_sparsemax_safe.to(z.dtype)
        tau = torch.where(k_sparsemax == 0, torch.full_like(tau, float("inf")), tau)
        # Project: everything below tau becomes exactly zero.
        probs = torch.clamp(z - tau, min=0)
weight_compute = self.weight.to(self.compute_dtype)
bias_compute = self.bias.to(self.compute_dtype) if self.bias is not None else None
out_c = torch.nn.functional.conv2d(
probs, weight_compute, bias_compute, stride=1, padding=self.K // 2, groups=self.groups
)
return out_c.masked_fill(~mask, zero).to(scores.dtype)
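# Illustrative only: a worked sparsemax example matching the thresholding
# above. For a row z = [2.0, 1.0, 0.1] (already sorted descending), the
# cumulative sums are [2.0, 3.0, 3.1]. The support test 1 + k*z_k > cumsum_k
# holds only at k = 1 (1 + 1*2.0 = 3.0 > 2.0), so the support size is 1 and
# tau = (2.0 - 1) / 1 = 1.0. The projection clamp(z - tau, min=0) then gives
# [1.0, 0.0, 0.0], which sums to 1; unlike softmax, trailing entries are
# exactly zero.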
def bench_speed_sparse_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
L = input.x
provider = input.kernel_provider
mode = input.kernel_operation_mode
extra_benchmark_config = input.extra_benchmark_config
B = extra_benchmark_config["B"]
C_in = extra_benchmark_config["C_in"]
C_out = extra_benchmark_config["C_out"]
K = extra_benchmark_config["K"]
groups = extra_benchmark_config["groups"]
bias = extra_benchmark_config["bias"]
dtype = extra_benchmark_config["dtype"]
x_shape = (B, C_in, L, L)
liger_attn = (
LigerMultiTokenAttention(
in_channels=C_in,
out_channels=C_out,
kernel_size=K,
stride=1,
padding=K // 2,
dilation=1,
groups=groups,
bias=bias,
sparse=True,
)
.to(device)
.to(dtype)
)
torch_attn = TorchSparseMultiTokenAttention(
C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device
)
with torch.no_grad():
torch.nn.init.kaiming_uniform_(liger_attn.weight, a=5**0.5)
if bias:
torch.nn.init.zeros_(liger_attn.bias)
torch_attn.weight.copy_(liger_attn.weight)
if bias:
torch_attn.bias.copy_(liger_attn.bias)
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)
def fwd():
if provider == "liger":
return liger_attn(x)
elif provider == "torch":
return torch_attn(x)
print(f"Starting Warmup for input size: {x_shape}")
    y = fwd()
    if mode in ("backward", "full"):
        y.backward(dy, retain_graph=True)
print("Done Warmup")
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, grad_to_none=[x], rep=100, quantiles=QUANTILES)
elif mode == "backward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(dy, retain_graph=True),
grad_to_none=[x],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd()
y.backward(dy, retain_graph=True)
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_sparse_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
L = input.x
provider = input.kernel_provider
extra_benchmark_config = input.extra_benchmark_config
B = extra_benchmark_config["B"]
C_in = extra_benchmark_config["C_in"]
C_out = extra_benchmark_config["C_out"]
K = extra_benchmark_config["K"]
groups = extra_benchmark_config["groups"]
bias = extra_benchmark_config["bias"]
dtype = extra_benchmark_config["dtype"]
x_shape = (B, C_in, L, L)
liger_attn = (
LigerMultiTokenAttention(
in_channels=C_in,
out_channels=C_out,
kernel_size=K,
stride=1,
padding=K // 2,
dilation=1,
groups=groups,
bias=bias,
sparse=True,
)
.to(device)
.to(dtype)
)
torch_attn = TorchSparseMultiTokenAttention(
C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device
)
with torch.no_grad():
torch.nn.init.kaiming_uniform_(liger_attn.weight, a=5**0.5)
if bias:
torch.nn.init.zeros_(liger_attn.bias)
torch_attn.weight.copy_(liger_attn.weight)
if bias:
torch_attn.bias.copy_(liger_attn.bias)
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)
def fwd():
if provider == "liger":
return liger_attn(x)
elif provider == "torch":
return torch_attn(x)
def full():
y = fwd()
y.backward(dy, retain_graph=True)
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "sparse_multi_token_attention",
"x_name": "L",
"x_label": "sequence length",
"x_values": [2**i for i in range(5, 10)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [
{
"B": 2,
"C_in": 4,
"C_out": 4,
"K": 3,
"groups": 1,
"bias": True,
"dtype": torch.float32,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_sparse_multi_token_attention,
kernel_operation_modes=["forward", "full", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_sparse_multi_token_attention,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)