Commit 9b0e3a30 authored by cmx

first commit

parent fe5cd1fc
BSD 2-CLAUSE LICENSE
Copyright 2024 LinkedIn Corporation
All Rights Reserved.
Redistribution and use in source and binary forms, with or
without modification, are permitted provided that the following
conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
.PHONY: test checkstyle test-convergence all serve build clean coverage run-benchmarks
all: checkstyle test test-convergence
# Command to run pytest for correctness tests
test:
python -m pytest --disable-warnings \
--cov=src/liger_kernel \
--cov-report=term-missing \
--ignore=test/convergence \
test/
# Command to run coverage report
coverage:
coverage report -m
# Command to run ruff for linting and formatting: report issues first, then auto-fix, and fail if the initial check or format check found problems
checkstyle:
ruff check --output-format=concise .; ruff_check_status=$$?; \
ruff format --check --diff .; ruff_format_status=$$?; \
ruff check . --fix; \
ruff format .; \
if [ $$ruff_check_status -ne 0 ] || [ $$ruff_format_status -ne 0 ]; then \
exit 1; \
fi
# Command to run pytest for convergence tests
# We have to explicitly set HF_DATASETS_OFFLINE=1, or dataset will silently try to send metrics and timeout (80s) https://github.com/huggingface/datasets/blob/37a603679f451826cfafd8aae00738b01dcb9d58/src/datasets/load.py#L286
test-convergence:
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models.py
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_multimodal.py
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_with_logits.py
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models.py
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_multimodal.py
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_with_logits.py
# Command to run all benchmark scripts and update benchmarking data file
# By default this doesn't overwrite existing data for the same benchmark experiment
# run with `make run-benchmarks OVERWRITE=1` to overwrite existing benchmark data
BENCHMARK_DIR = benchmark/scripts
BENCHMARK_SCRIPTS = $(wildcard $(BENCHMARK_DIR)/benchmark_*.py)
OVERWRITE ?= 0
run-benchmarks:
@for script in $(BENCHMARK_SCRIPTS); do \
echo "Running benchmark: $$script"; \
if [ $(OVERWRITE) -eq 1 ]; then \
python $$script --overwrite; \
else \
python $$script; \
fi; \
done
# MkDocs Configuration
MKDOCS = mkdocs
CONFIG_FILE = mkdocs.yml
SITE_DIR = site
# MkDocs targets
# Serve the documentation
serve:
$(MKDOCS) serve -f $(CONFIG_FILE)
# Build the documentation into the specified site directory
build:
$(MKDOCS) build -f $(CONFIG_FILE) --site-dir $(SITE_DIR)
# Clean the output directory
clean:
rm -rf $(SITE_DIR)/
Copyright 2024 LinkedIn Corporation
All Rights Reserved.
Licensed under the BSD 2-Clause License (the "License"). See License in the project root for license information.
This product includes software developed by LinkedIn Corporation.
This product contains code derived from the following open source projects:
1. Unsloth
Copyright (c) 2023 Unsloth AI
Licensed under the Apache License, Version 2.0
Source: https://github.com/unslothai/unsloth
The `calculate_settings` function, which determines block size and warp count, is reused for the Norm and MLP operations.
Modifications and additions were made to the RMS Norm implementation.
2. Triton
Copyright (c) 2023 OpenAI
Licensed under the MIT License
Source: https://github.com/openai/triton
Modifications were made based on Triton tutorials for the RMS Norm implementation.
3. Efficient Cross Entropy
Copyright (c) 2023 Mohamed Malek
Licensed under the MIT License
Source: https://github.com/mgmalek/efficient_cross_entropy
The idea of gradient-in-forward and chunking was used in the Linear Cross Entropy implementation.
4. Flash Attention
Copyright (c) 2023 Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
Licensed under the BSD 3-Clause License
Source: https://github.com/Dao-AILab/flash-attention
Optimization ideas such as tiling and recomputation were inspired by this work.
5. AutoAWQ
Copyright (c) 2023 Casper Hansen
Licensed under the MIT License
Source: https://github.com/casper-hansen/AutoAWQ
The design of the automodel was referenced from this project.
6. llm.c
Copyright (c) 2023 Andrej Karpathy
Licensed under the MIT License
Source: https://github.com/karpathy/llm.c
The design of end-to-end testing was referenced from this project.
7. Tiny Shakespeare Dataset
Source: https://huggingface.co/datasets/karpathy/tiny_shakespeare
This dataset is used to conduct convergence tests on mini models.
For full license texts, please refer to the respective project repositories.
# Guideline for Adding Benchmark Scripts
This document describes how to add new benchmark scripts to Liger-Kernel in line with the shared framework.
## 1. Where and how to add a script
- **Location**: `benchmark/scripts/`
- **Naming**: `benchmark_<kernel_name>.py` (e.g. `benchmark_geglu.py`, `benchmark_swiglu.py`)
## 2. Use shared infrastructure
Do **not** hardcode batch size, sequence length, or model dimensions. Use:
| Need | Use |
|------|-----|
| Model dimensions (hidden_size, vocab_size, etc.) | `benchmark_model_configs.py`: `ModelConfig`, `get_benchmark_model_config()` |
| Safe sweep config (seq_len or hidden_size) | `compute_seq_len_sweep_config()` (returns `SeqLenSweepConfig`) or `compute_hidden_size_sweep_config()` (returns `HiddenSizeSweepConfig`), with optional `estimate_kernel_peak_memory()` |
| Speed / memory measurement | `utils.py`: `run_speed_benchmark()`, `run_memory_benchmark()` |
| CLI (overwrite, model choice) | `utils.py`: `parse_benchmark_script_args()` (includes `--model`) |
| Running the grid and writing CSV | `utils.py`: `run_benchmarks()` |
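For orientation, a minimal sketch of the entry-point wiring (the names are the ones listed in the table above; exact signatures live in `utils.py` and `benchmark_model_configs.py`):

```python
from benchmark_model_configs import get_benchmark_model_config
from utils import parse_benchmark_script_args

args = parse_benchmark_script_args()            # provides --model and --overwrite
model = get_benchmark_model_config(args.model)  # ModelConfig: hidden_size, vocab_size, dtype, ...
```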
## 3. Script structure (three parts)
### 3.1 Setup factory
Define a single **setup function** that builds inputs and the layer (or callable) from `SingleBenchmarkRunInput`, so both speed and memory benchmarks reuse the same setup.
- **Signature**: `_setup_<kernel>(input: SingleBenchmarkRunInput) -> (tensors, layer_or_fn)`
- **Input**: `input.x` is the varying dimension (e.g. sequence length); `input.extra_benchmark_config` holds `bsz`, `hidden_size`, `dtype`, etc.; `input.kernel_provider` identifies the implementation variant (e.g. `"liger"`, `"huggingface"`, `"torch"`; values are kernel-specific).
- **Return**: Whatever the benchmark helpers need (e.g. `(x, layer)` for a single-tensor forward like GEGLU).
Example (conceptually):
```python
def _setup_geglu(input: SingleBenchmarkRunInput):
cfg = input.extra_benchmark_config
# Build config, create x tensor, instantiate LigerGEGLUMLP or LlamaMLP by provider
return x, layer
```
### 3.2 Speed and memory benchmark functions
Each takes `SingleBenchmarkRunInput` and returns `SingleBenchmarkRunOutput` by calling the shared helpers.
- **Speed**: `run_speed_benchmark(fwd_fn, mode, input_tensors, rep=...)`
- **Memory**: `run_memory_benchmark(fwd_fn, mode)`
- **Modes**: Use `["full", "forward", "backward"]` for both speed and memory for consistency.
Example:
```python
def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
x, layer = _setup_geglu(input)
return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x])
def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
x, layer = _setup_geglu(input)
return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode)
```
For **scalar output** (e.g. loss) or **multiple outputs** (e.g. RoPE), use the appropriate helpers from `utils.py` if available (e.g. loss or multi-output variants), or implement custom measurement and still use the same setup factory and `run_benchmarks()`.
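As a rough illustration, a custom speed measurement for a scalar-loss kernel can still reuse the setup factory. Here `_setup_myloss` and its return shape are hypothetical; only `run_speed_benchmark(fwd_fn, mode, input_tensors)` comes from the framework:

```python
def bench_speed_myloss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    # Hypothetical setup factory returning input tensors and a loss callable
    (x, target), loss_fn = _setup_myloss(input)
    return run_speed_benchmark(lambda: loss_fn(x, target), input.kernel_operation_mode, [x])
```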
### 3.3 `__main__`: model config, shape computation, run
1. Parse args: `args = parse_benchmark_script_args()` and resolve `model = get_benchmark_model_config(args.model)`.
2. (Recommended) Measure peak memory with a small probe using the **highest-memory baseline** implementation (e.g. `"huggingface"` or `"torch"`):
- Define a `_probe()` function that creates tensors/layers, runs a forward pass, and returns the output tensor. `_probe()` owns setup; `estimate_kernel_peak_memory` handles memory-stat reset before the call, runs `.backward()`, and performs cleanup (gc + cache clear) afterward.
- Call `peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)`.
3. Compute sweep config (device memory is obtained internally by both helpers):
- **Sequence-length sweep** (e.g. GEGLU, SwiGLU): convert peak bytes to per-token (`kernel_bpt = peak_bytes // probe_seq_len`), then `config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)`. The returned `SeqLenSweepConfig` has `batch_size` and `seq_len`.
- **Hidden-size sweep** (e.g. DyT): pass total peak bytes directly: `config = compute_hidden_size_sweep_config(model, kernel_peak_bytes=peak_bytes, bt=BT)`. The returned `HiddenSizeSweepConfig` has `bt` and `max_hidden_size`.
4. Build `x_values` from `config.seq_len` (seq_len sweep) or `config.max_hidden_size` (hidden_size sweep).
5. Build `extra_benchmark_configs` from `model` and config:
- Seq_len sweep: e.g. `bsz=config.batch_size`, `hidden_size=model.hidden_size`, `dtype=model.dtype`.
- Hidden_size sweep: e.g. `BT=config.bt`, `dtype=model.dtype`.
6. Call `run_benchmarks(..., kernel_operation_modes=["full", "forward", "backward"], ...)` for both speed and memory.
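Putting steps 1–6 together, a seq_len-sweep `__main__` might look like the sketch below. This is illustrative only: `_probe`, `PROBE_SEQ_LEN`, the GEGLU names, and the exact `x_values` derivation are assumptions, not the canonical script.

```python
PROBE_SEQ_LEN = 1024  # assumed probe size for the memory estimate

if __name__ == "__main__":
    args = parse_benchmark_script_args()                        # step 1
    model = get_benchmark_model_config(args.model)

    peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)   # step 2
    kernel_bpt = peak_bytes // PROBE_SEQ_LEN                    # peak bytes per token
    config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)  # step 3

    # steps 4-5: sweep values and per-run config
    x_values = [2**i for i in range(10, 16) if 2**i <= config.seq_len]
    extra = [{"bsz": config.batch_size, "hidden_size": model.hidden_size, "dtype": model.dtype}]

    for bench_fn, metric, unit in [
        (bench_speed_geglu, "speed", "ms"),
        (bench_memory_geglu, "memory", "MB"),
    ]:
        run_benchmarks(                                         # step 6
            bench_test_fn=bench_fn,
            kernel_operation_modes=["full", "forward", "backward"],
            metric_name=metric,
            metric_unit=unit,
            kernel_name="geglu",
            x_name="T",
            x_label="sequence length",
            x_values=x_values,
            kernel_providers=["liger", "huggingface"],
            extra_benchmark_configs=extra,
            overwrite=args.overwrite,
        )
```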
## 4. CLI
Scripts should support:
- `--overwrite`: overwrite existing rows in the benchmark CSV.
- `--model`: model profile name from `MODEL_REGISTRY` (e.g. `llama_2_7b`, `llama_3_8b`). Default when not set is `DEFAULT_MODEL_CONFIG` (e.g. `llama_3_8b`).
These are provided by `parse_benchmark_script_args()` in `utils.py`.
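For example, from the repo root (an illustrative invocation; flags as above):

```bash
python benchmark/scripts/benchmark_geglu.py --model llama_3_8b --overwrite
```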
## 5. Reference scripts
- **Element-wise (single tensor in/out, seq_len sweep)**: `benchmark_geglu.py`, `benchmark_swiglu.py` (use `compute_seq_len_sweep_config()`).
- **Element-wise (single tensor in/out, hidden_size sweep)**: `benchmark_dyt.py` (uses `compute_hidden_size_sweep_config()`).
## 6. Checklist for a new script
- [ ] Script under `benchmark/scripts/` named `benchmark_<kernel>.py`.
- [ ] Single `_setup_<kernel>(SingleBenchmarkRunInput)` used by both speed and memory.
- [ ] Speed/memory implemented via `run_speed_benchmark` / `run_memory_benchmark` (or the correct variant for loss / multi-output).
- [ ] `kernel_operation_modes=["full", "forward", "backward"]` for both speed and memory.
- [ ] No hardcoded batch size or sequence length; use `compute_seq_len_sweep_config()` or `compute_hidden_size_sweep_config()` (and optionally `estimate_kernel_peak_memory()`).
- [ ] Model dimensions and dtype from `ModelConfig` / `get_benchmark_model_config()` / `args.model`.
- [ ] CLI via `parse_benchmark_script_args()` (so `--model` and `--overwrite` work).
- [ ] Results written through `run_benchmarks()` so data goes to the shared CSV.
## Benchmarking Liger Kernels
Follow these steps to benchmark and visualize kernel performance:
1. Create a benchmark script
- Add your script under `benchmark/scripts/`
- Name it according to the kernel (e.g., `benchmark_<kernel_name>.py`)
2. Run the benchmark
- Results will be saved to `benchmark/data/all_benchmark_data.csv`
Example: Benchmarking KTO Loss
```bash
cd benchmark
python scripts/benchmark_kto_loss.py
```
3. Visualize results
- Use the visualization script with optional modes:
* To target specific mode(s), pass one or more values to `--kernel-operation-mode`.
* If you omit `--kernel-operation-mode`, the script will:
- For `speed` metrics: generate plots for all available modes (forward/backward/full).
- For `memory` metrics: generate only the `full` plot.
Examples:
1. Specific modes (speed):
```bash
python benchmarks_visualizer.py \
--kernel-name kto_loss \
--metric-name speed \
--kernel-operation-mode forward backward
```
2. All modes (speed):
```bash
python benchmarks_visualizer.py \
--kernel-name kto_loss \
--metric-name speed
```
3. Memory (always full):
```bash
python benchmarks_visualizer.py \
--kernel-name kto_loss \
--metric-name memory
```
4. View results
- Generated plots will be saved in `benchmark/visualizations/`
import json
import os
import sys
from argparse import ArgumentParser
from dataclasses import dataclass
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
DATA_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "data/all_benchmark_data.csv"))
VISUALIZATIONS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "visualizations/"))
@dataclass
class VisualizationsConfig:
"""
Configuration for the visualizations script.
Args:
kernel_name (str): Kernel name to visualize (as recorded in the benchmark CSV by `scripts/benchmark_{kernel_name}.py`)
metric_name (str): Metric name to visualize (speed/memory)
kernel_operation_mode (str): Kernel operation mode to visualize (forward/backward/full). Defaults to "full"
extra_config_filter (str, optional): A string to filter extra_benchmark_config.
Can be a substring to match or a 'key=value' pair (e.g., "'H': 4096").
Defaults to None, which means the first available config will be used if multiple exist.
display (bool): Display the visualization. Defaults to False
overwrite (bool): Overwrite an existing visualization; if none exists, the flag has no effect since one is always created and saved. Defaults to False
"""
kernel_name: str
metric_name: str
kernel_operation_mode: str = "full"
extra_config_filter: str | None = None
display: bool = False
overwrite: bool = False
def parse_args():
"""Parse command line arguments.
Returns:
argparse.Namespace: Parsed arguments; `main()` builds a `VisualizationsConfig` per mode from them.
"""
parser = ArgumentParser()
parser.add_argument("--kernel-name", type=str, required=True, help="Kernel name to benchmark")
parser.add_argument(
"--metric-name",
type=str,
required=True,
help="Metric name to visualize (speed/memory)",
)
parser.add_argument(
"--kernel-operation-mode",
type=str,
nargs="*",
default=None,
help="Kernel operation modes to visualize (forward/backward/full). If not provided, generate for all available modes.",
)
parser.add_argument(
"--extra-config-filter",
type=str,
default=None,
help="A string to filter extra_benchmark_config. "
"Can be a substring to match or a JSON-like 'key=value' pair (e.g., \"'H': 4096\" or \"H=4096\" for simple cases). "
"Defaults to None (first available config if multiple exist).",
)
parser.add_argument("--display", action="store_true", help="Display the visualization")
parser.add_argument(
"--overwrite",
action="store_true",
help="Overwrite existing visualization, if none exist this flag has no effect as one are always created",
)
args = parser.parse_args()
return args
def load_data(config: VisualizationsConfig) -> pd.DataFrame:
"""Loads the benchmark data from the CSV file and filters it based on the configuration.
Args:
config (VisualizationsConfig): Configuration object for the visualizations script.
Raises:
ValueError: If no data is found for the given filters.
Returns:
pd.DataFrame: Filtered benchmark dataframe.
"""
df = pd.read_csv(DATA_PATH)
df["extra_benchmark_config"] = df["extra_benchmark_config_str"].apply(json.loads)
base_filtered_df = df[
(df["kernel_name"] == config.kernel_name)
& (df["metric_name"] == config.metric_name)
& (df["kernel_operation_mode"] == config.kernel_operation_mode)
]
if base_filtered_df.empty:
raise ValueError(
f"No data found for kernel_name='{config.kernel_name}', "
f"metric_name='{config.metric_name}', "
f"kernel_operation_mode='{config.kernel_operation_mode}'."
)
unique_extra_configs_str = base_filtered_df["extra_benchmark_config_str"].unique()
selected_extra_config_str = None
if len(unique_extra_configs_str) == 0:
print(
"Warning: No extra_benchmark_config found for the initial filters. "
"Proceeding with all data from initial filter."
)
return base_filtered_df
if config.extra_config_filter:
matched_configs = []
try:
if "=" in config.extra_config_filter:
key_filter, value_filter = config.extra_config_filter.split("=", 1)
for cfg_str in unique_extra_configs_str:
cfg_json = json.loads(cfg_str)
if str(cfg_json.get(key_filter.strip("'\" "))) == value_filter.strip("'\" "):
matched_configs.append(cfg_str)
if not matched_configs:
matched_configs = [
cfg_str for cfg_str in unique_extra_configs_str if config.extra_config_filter in cfg_str
]
except Exception as e:
print(
f"Note: Could not parse extra_config_filter '{config.extra_config_filter}' as key=value ({e}), using substring match."
)
matched_configs = [cfg_str for cfg_str in unique_extra_configs_str if config.extra_config_filter in cfg_str]
if matched_configs:
if len(matched_configs) > 1:
print(
f"Warning: Multiple extra_benchmark_configs match filter '{config.extra_config_filter}': {matched_configs}. "
f"Using the first one: {matched_configs[0]}"
)
selected_extra_config_str = matched_configs[0]
else:
print(
f"Warning: No extra_benchmark_config matches filter '{config.extra_config_filter}'. "
f"Available configs for {config.kernel_name} ({config.metric_name}, {config.kernel_operation_mode}): {list(unique_extra_configs_str)}"
)
if len(unique_extra_configs_str) > 0:
selected_extra_config_str = unique_extra_configs_str[0]
print(f"Defaulting to the first available extra_benchmark_config: {selected_extra_config_str}")
else:
raise ValueError("No extra_benchmark_config available to select after failed filter attempt.")
elif len(unique_extra_configs_str) > 1:
selected_extra_config_str = unique_extra_configs_str[0]
print(
f"Warning: Multiple extra_benchmark_configs found for {config.kernel_name} ({config.metric_name}, {config.kernel_operation_mode})."
)
print(f"Defaulting to use: {selected_extra_config_str}")
print(f"Available configs: {list(unique_extra_configs_str)}")
print(
"Use the --extra-config-filter argument to select a specific one "
"(e.g., --extra-config-filter \"'H': 4096\" or a substring like \"'seq_len': 512\")."
)
elif len(unique_extra_configs_str) == 1:
selected_extra_config_str = unique_extra_configs_str[0]
print(f"Using unique extra_benchmark_config: {selected_extra_config_str}")
if selected_extra_config_str:
final_filtered_df = base_filtered_df[
base_filtered_df["extra_benchmark_config_str"] == selected_extra_config_str
]
else:
print("Warning: Could not select an extra_benchmark_config. Using data from initial filter if any.")
final_filtered_df = base_filtered_df
if final_filtered_df.empty:
raise ValueError(
f"No data found after attempting to filter by extra_benchmark_config. "
f"Selected/Defaulted extra_config_str: {selected_extra_config_str}"
if selected_extra_config_str
else "No specific extra_config was selected."
)
print(
f"Plotting data for extra_benchmark_config: {json.loads(selected_extra_config_str if selected_extra_config_str else '{}')}"
)
return final_filtered_df
def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
"""Plots the benchmark data, saving the result if needed.
Args:
df (pd.DataFrame): Filtered benchmark dataframe.
config (VisualizationsConfig): Configuration object for the visualizations script.
"""
for col in ["y_value_20", "y_value_50", "y_value_80"]:
if col in df.columns:
df[col] = pd.to_numeric(df[col], errors="coerce")
xlabel = df["x_label"].iloc[0]
ylabel = f"{config.metric_name} ({df['metric_unit'].iloc[0]})"
# Sort by "kernel_provider" to ensure consistent color assignment
df = df.sort_values(by="kernel_provider")
plt.figure(figsize=(10, 6))
sns.set(style="whitegrid")
try:
ax = sns.lineplot(
data=df,
x="x_value",
y="y_value_50",
hue="kernel_provider",
marker="o",
palette="tab10",
errorbar=("ci", None),
)
except Exception:
ax = sns.lineplot(
data=df,
x="x_value",
y="y_value_50",
hue="kernel_provider",
marker="o",
palette="tab10",
errorbar=None,
)
# Seaborn can't plot pre-computed error bars, so we need to do it manually
lines = ax.get_lines()
colors = [line.get_color() for line in lines]
for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
y_error = [y_error_lower, y_error_upper]
plt.errorbar(
group_data["x_value"],
group_data["y_value_50"],
yerr=y_error,
fmt="o",
color=color,
capsize=5,
)
plt.legend(title="Kernel Provider")
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.tight_layout()
out_path = os.path.join(
VISUALIZATIONS_PATH,
f"{config.kernel_name}_{config.metric_name}_{config.kernel_operation_mode}.png",
)
if config.display:
plt.show()
if config.overwrite or not os.path.exists(
out_path
): # Save the plot if it doesn't exist or if we want to overwrite it
os.makedirs(VISUALIZATIONS_PATH, exist_ok=True)
plt.savefig(out_path)
plt.close()
def main():
args = parse_args()
all_df = pd.read_csv(DATA_PATH)
all_df["extra_benchmark_config"] = all_df["extra_benchmark_config_str"].apply(json.loads)
if args.metric_name == "memory":
modes = ["full"]
elif args.kernel_operation_mode:
modes = args.kernel_operation_mode
else:
filtered = all_df[(all_df["kernel_name"] == args.kernel_name) & (all_df["metric_name"] == args.metric_name)]
modes = filtered["kernel_operation_mode"].unique().tolist()
if not modes:
print(f"No data found for kernel '{args.kernel_name}' and metric '{args.metric_name}'.", file=sys.stderr)
sys.exit(1)
for mode in modes:
config = VisualizationsConfig(
kernel_name=args.kernel_name,
metric_name=args.metric_name,
kernel_operation_mode=mode,
extra_config_filter=args.extra_config_filter,
display=args.display,
overwrite=args.overwrite,
)
df = load_data(config)
plot_data(df, config)
if __name__ == "__main__":
main()
import os
import sys
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.utils import infer_device
device = infer_device()
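# Make the repo root importable so the `test.chunked_loss` helpers used below resolve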
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
#############################################################################
# Test the memory consumption of the fused linear CPO loss
#############################################################################
def bench_memory_fused_linear_cpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO
from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
# Instantiate once and retrieve the first output only
torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
torch_fwd = lambda x, target: torch_lm_head_cpo(x, target)[0]
liger_fwd = lambda x, target: liger_lm_head_cpo(x, target)[0]
_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
def fwd():
if provider == "liger":
return liger_fwd(_input, target)
elif provider == "huggingface":
return torch_fwd(_input, target)
def full():
y = fwd()
y.backward()
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
#############################################################################
# Test the speed of the fused linear CPO loss
#############################################################################
def bench_speed_fused_linear_cpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO
from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
mode = input.kernel_operation_mode
# Instantiate once and retrieve the first output only
torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
torch_fwd = lambda x, target: torch_lm_head_cpo(x, target)[0]
liger_fwd = lambda x, target: liger_lm_head_cpo(x, target)[0]
_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
def fwd():
if provider == "liger":
return liger_fwd(_input, target)
elif provider == "huggingface":
return torch_fwd(_input, target)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[_input],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd()
y.backward()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "fused_linear_cpo_loss",
"x_name": "B",
"x_label": "B",
"x_values": [2**i for i in range(1, 5)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"T": 1024,
"H": 4096,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_fused_linear_cpo_loss,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_fused_linear_cpo_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import torch
import triton
from torch.nn import CrossEntropyLoss
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.utils import infer_device
device = infer_device()
def bench_memory_cross_entropy(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
torch_ce = CrossEntropyLoss()
liger_ce = LigerCrossEntropyLoss()
V = input.x
provider = input.kernel_provider
B = input.extra_benchmark_config["B"]
T = input.extra_benchmark_config["T"]
_input = torch.randn(B * T, V, requires_grad=True, device=device)
target = torch.randint(V, (B * T, 1), device=device).squeeze(1)
def fwd():
if provider == "liger":
return liger_ce(_input, target)
else:
return torch_ce(_input, target)
def full():
y = fwd()
y.backward()
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
def bench_speed_cross_entropy(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
torch_ce = CrossEntropyLoss()
liger_ce = LigerCrossEntropyLoss()
V = input.x
provider = input.kernel_provider
mode = input.kernel_operation_mode
B = input.extra_benchmark_config["B"]
T = input.extra_benchmark_config["T"]
_input = torch.randn(B * T, V, requires_grad=True, device=device)
target = torch.randint(V, (B * T, 1), device=device).squeeze(1)
def fwd():
if provider == "liger":
return liger_ce(_input, target)
else:
return torch_ce(_input, target)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
elif mode == "no-grad-forward":
with torch.no_grad():
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[_input],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd()
y.backward()
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "cross_entropy",
"x_name": "V",
"x_label": "vocab size",
"x_values": [2**i for i in range(12, 18)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [{"B": 8, "T": 2048}],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_cross_entropy,
kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_cross_entropy,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import os
import sys
import torch
import torch.nn as nn
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
from liger_kernel.utils import infer_device
device = infer_device()
# Ensure the project root is in the path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
class TorchCosineSimilarityLoss(nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
bias: bool = False,
):
from test.chunked_loss.test_cosine_loss import HFCosineLoss
super().__init__()
self.student_lin = nn.Linear(in_features=H // 2, out_features=V, bias=bias).to(dtype=dtype)
self.teacher_lin = nn.Linear(in_features=H, out_features=V, bias=bias).to(dtype=dtype)
self.cosine_loss = HFCosineLoss(
ignore_index=ignore_index,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
temperature=temperature,
).get_batch_loss_metrics
def forward(self, student: torch.Tensor, teacher: torch.Tensor, target: torch.Tensor):
return self.cosine_loss(student, self.student_lin.weight, teacher, self.teacher_lin.weight, target)
class LigerCosineSimilarityLoss(nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
bias: bool = False,
):
super().__init__()
self.student_lin = nn.Linear(in_features=H // 2, out_features=V, bias=bias).to(dtype=dtype)
self.teacher_lin = nn.Linear(in_features=H, out_features=V, bias=bias).to(dtype=dtype)
self.weight_hard_loss = weight_hard_loss
self.weight_soft_loss = weight_soft_loss
self.ignore_index = ignore_index
self.temperature = temperature
self.cosine_loss = LigerFusedLinearCosineSimilarityFunction.apply
def forward(self, student: torch.Tensor, teacher: torch.Tensor, target: torch.Tensor):
return self.cosine_loss(
student,
self.student_lin.weight,
teacher,
self.teacher_lin.weight,
target,
self.student_lin.bias,
self.teacher_lin.bias,
self.weight_hard_loss,
self.weight_soft_loss,
)
def bench_memory_cosine_similarity_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider
torch_cosine_loss = TorchCosineSimilarityLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)
liger_cosine_loss = LigerCosineSimilarityLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)
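# Two independent leaf clones of the same data so each provider backprops into its own graph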
_tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
student_input1 = _tensor.detach().clone().requires_grad_(True)
student_input2 = _tensor.detach().clone().requires_grad_(True)
teacher_input = torch.rand(BT, H, device=device, dtype=dtype)
target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)
def fwd():
if provider == "liger":
return liger_cosine_loss(student_input1, teacher_input, target)
elif provider == "torch":
return torch_cosine_loss(student_input2, teacher_input, target)
def full():
y = fwd()
y.backward()
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
def bench_speed_cosine_similarity_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider
mode = input.kernel_operation_mode
torch_cosine_loss = TorchCosineSimilarityLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)
liger_cosine_loss = LigerCosineSimilarityLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)
_tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
student_input1 = _tensor.detach().clone().requires_grad_(True)
student_input2 = _tensor.detach().clone().requires_grad_(True)
teacher_input = torch.rand(BT, H, device=device, dtype=dtype)
target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)
def fwd():
if provider == "liger":
return liger_cosine_loss(student_input1, teacher_input, target)
elif provider == "torch":
return torch_cosine_loss(student_input2, teacher_input, target)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[student_input1, student_input2],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd()
y.backward()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "distill_cosine_loss",
"x_name": "BT",
"x_label": "B x T",
"x_values": [2**i for i in range(10, 14)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [
{
"H": 4096,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
"bias": False,
"weight_hard_loss": 0.5,
"weight_soft_loss": 0.5,
"ignore_index": -100,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_cosine_similarity_loss,
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_cosine_similarity_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import os
import sys
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
from liger_kernel.utils import get_total_gpu_memory
from liger_kernel.utils import infer_device
device = infer_device()
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
class TorchJSDLoss(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
bias: bool = False,
):
from test.chunked_loss.test_jsd_loss import HFJSDLoss
super().__init__()
self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype)
self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
self.jsd_loss = HFJSDLoss(
ignore_index=ignore_index,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
temperature=temperature,
).get_batch_loss_metrics
def forward(self, student, teacher, target):
return self.jsd_loss(
student,
self.student_lin.weight,
teacher,
self.teacher_lin.weight,
target,
)
class LigerJSDLoss(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
bias: bool = False,
):
super().__init__()
self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype)
self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
self.weight_hard_loss = weight_hard_loss
self.weight_soft_loss = weight_soft_loss
self.ignore_index = ignore_index
self.temperature = temperature
self.jsd_loss = LigerFusedLinearJSDFunction.apply
def forward(self, student, teacher, target):
return self.jsd_loss(
student,
self.student_lin.weight,
teacher,
self.teacher_lin.weight,
target,
self.student_lin.bias,
self.teacher_lin.bias,
self.weight_hard_loss,
self.weight_soft_loss,
)
def bench_memory_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider
torch_jsd_loss = TorchJSDLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)
liger_jsd_loss = LigerJSDLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)
_tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
student_input1 = _tensor.detach().clone().requires_grad_(True)
student_input2 = _tensor.detach().clone().requires_grad_(True)
teacher_input = torch.rand(BT, H, device=device, dtype=dtype)
target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)
def fwd():
if provider == "liger":
return liger_jsd_loss(student_input1, teacher_input, target)
elif provider == "torch":
return torch_jsd_loss(student_input2, teacher_input, target)
def full():
y = fwd()
y.backward()
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
def bench_speed_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider
mode = input.kernel_operation_mode
torch_jsd_loss = TorchJSDLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)
liger_jsd_loss = LigerJSDLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)
_tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
student_input1 = _tensor.detach().clone().requires_grad_(True)
student_input2 = _tensor.detach().clone().requires_grad_(True)
teacher_input = torch.rand(BT, H, device=device, dtype=dtype)
target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)
def fwd():
if provider == "liger":
return liger_jsd_loss(student_input1, teacher_input, target)
elif provider == "torch":
return torch_jsd_loss(student_input2, teacher_input, target)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[student_input1, student_input2],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd()
y.backward()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
gpu_memory_gbs = get_total_gpu_memory()
# The full torch run needs ~69 GB at BT = 2^13 and ~39 GB at BT = 2^12, so cap the BT sweep to fit GPU memory
if gpu_memory_gbs >= 69:
x_max = 13
elif gpu_memory_gbs >= 39:
x_max = 12
else:
x_max = 11
common_configs = {
"kernel_name": "distill_jsd_loss",
"x_name": "BT",
"x_label": "B x T",
"x_values": [2**i for i in range(10, x_max + 1)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [
{
"H": 4096,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
"bias": False,
"weight_hard_loss": 0.5,
"weight_soft_loss": 0.5,
"ignore_index": -100,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_jsd_loss,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_jsd_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import os
import sys
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.utils import infer_device
device = infer_device()
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO
from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
beta = input.extra_benchmark_config["beta"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider
# Instantiate once and retrieve the first output only
torch_dpo_loss = TorchLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
liger_dpo_loss = LigerLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
torch_fwd = lambda x, ref_x, target: torch_dpo_loss(x, ref_x, target)[0]
liger_fwd = lambda x, ref_x, target: liger_dpo_loss(x, ref_x, target)[0]
# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)
ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False)
# Target shape: [B, T]
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
# Add ignore_index tokens to simulate padding
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index
def fwd():
if provider == "liger":
return liger_fwd(_input, ref_input, target)
elif provider == "huggingface":
return torch_fwd(_input, ref_input, target)
def full():
y = fwd()
y.backward()
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO
from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
beta = input.extra_benchmark_config["beta"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider
mode = input.kernel_operation_mode
# Instantiate once and retrieve the first output only
torch_dpo_loss = TorchLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
liger_dpo_loss = LigerLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
torch_fwd = lambda x, ref_x, target: torch_dpo_loss(x, ref_x, target)[0]
liger_fwd = lambda x, ref_x, target: liger_dpo_loss(x, ref_x, target)[0]
# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)
ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False)
# Target shape: [B, T]
target = torch.randint(V, (B, T), device=device, dtype=torch.long)
# Add ignore_index tokens
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index
def fwd():
if provider == "liger":
return liger_fwd(_input, ref_input, target)
elif provider == "huggingface":
return torch_fwd(_input, ref_input, target)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[_input],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd()
y.backward()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "dpo_loss",
"x_name": "B",
"x_label": "Batch Size (B)",
"x_values": [2**i for i in range(1, 6)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"T": 512,
"H": 1024,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
"bias": True,
"beta": 0.1,
"ignore_index": 42,
}
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_dpo_loss,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_dpo_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import os
import sys
import torch
from benchmark_model_configs import compute_hidden_size_sweep_config
from benchmark_model_configs import estimate_kernel_peak_memory
from benchmark_model_configs import get_benchmark_model_config
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from utils import run_memory_benchmark
from utils import run_speed_benchmark
from liger_kernel.utils import infer_device
device = infer_device()
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
def _setup_dyt(input: SingleBenchmarkRunInput):
"""Create input tensor and DyT layer from benchmark config."""
from test.transformers.test_dyt import LigerDyT
from test.transformers.test_dyt import TorchDyT
cfg = input.extra_benchmark_config
hidden_size = input.x
x = torch.randn(cfg["BT"], hidden_size, device=device, dtype=cfg["dtype"], requires_grad=True)
if input.kernel_provider == "liger":
layer = LigerDyT(hidden_size=hidden_size, beta=cfg["beta"]).to(device)
elif input.kernel_provider == "torch":
layer = TorchDyT(hidden_size=hidden_size, beta=cfg["beta"]).to(device)
elif input.kernel_provider == "torch_compile":
layer = torch.compile(TorchDyT(hidden_size=hidden_size, beta=cfg["beta"]).to(device))
else:
raise ValueError(f"Invalid provider: {input.kernel_provider} for DyT")
return x, layer
def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
x, layer = _setup_dyt(input)
return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x])
def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
x, layer = _setup_dyt(input)
return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode)
BT = 4096
if __name__ == "__main__":
args = parse_benchmark_script_args()
model = get_benchmark_model_config(args.model)
for beta in [False, True]:
def _probe():
probe_input = SingleBenchmarkRunInput(
x=model.hidden_size,
kernel_provider="torch",
extra_benchmark_config={"BT": BT, "dtype": model.dtype, "beta": beta},
)
x, layer = _setup_dyt(probe_input)
return layer(x)
peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
sweep_config = compute_hidden_size_sweep_config(model, peak_bytes, bt=BT)
x_values = [1024 * i for i in range(1, 17) if 1024 * i <= sweep_config.max_hidden_size] or [model.hidden_size]
common_configs = {
"kernel_name": f"dyt_beta={beta}",
"x_name": "hidden_size",
"x_label": "hidden_size",
"x_values": x_values,
"kernel_providers": ["liger", "torch", "torch_compile"],
"extra_benchmark_configs": [{"BT": sweep_config.bt, "dtype": model.dtype, "beta": beta}],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_dyt,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_dyt,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import torch
import triton
from torch.nn import Embedding
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.experimental.embedding import LigerEmbedding
from liger_kernel.utils import infer_device
device = infer_device()
# NOTE: For torch compile, we will just use default inductor settings. No further customization
# is needed.
def bench_speed_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
V = input.x
provider = input.kernel_provider
mode = input.kernel_operation_mode
B = input.extra_benchmark_config["B"]
T = input.extra_benchmark_config["T"]
D = input.extra_benchmark_config["D"]
dtype = input.extra_benchmark_config["dtype"]
torch_emb = Embedding(V, D).to(device).to(dtype)
liger_emb = LigerEmbedding(V, D).to(device).to(dtype)
torch_compile_emb = torch.compile(torch_emb)
input_ids = torch.randint(0, V, (B, T), device=device)
def fwd():
if provider == "liger":
return liger_emb(input_ids)
elif provider == "torch_compile":
return torch_compile_emb(input_ids)
else:
return torch_emb(input_ids)
def full():
output = fwd()
output.backward(torch.randn_like(output))
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
elif mode == "backward":
output = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: output.backward(torch.randn_like(output), retain_graph=True),
quantiles=QUANTILES,
grad_to_none=[input_ids],
rep=100,
)
elif mode == "full":
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
V = input.x
provider = input.kernel_provider
B = input.extra_benchmark_config["B"]
T = input.extra_benchmark_config["T"]
D = input.extra_benchmark_config["D"]
dtype = input.extra_benchmark_config["dtype"]
torch_emb = Embedding(V, D).to(device).to(dtype)
liger_emb = LigerEmbedding(V, D).to(device).to(dtype)
torch_compile_emb = torch.compile(torch_emb)
input_ids = torch.randint(0, V, (B, T), device=device)
def fwd():
if provider == "liger":
return liger_emb(input_ids)
elif provider == "torch_compile":
return torch_compile_emb(input_ids)
else:
return torch_emb(input_ids)
def full():
output = fwd()
output.backward(torch.randn_like(output))
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "embedding",
"x_name": "V",
"x_label": "embedding dimension",
"x_values": [2**i for i in range(10, 18)],
"kernel_providers": ["liger", "huggingface", "torch_compile"],
"extra_benchmark_configs": [
# BERT
{"B": 32, "T": 512, "D": 768, "dtype": torch.float32},
# Llama
{"B": 8, "T": 2048, "D": 4096, "dtype": torch.float32},
],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_embedding,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_embedding,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import torch
import torch.nn as nn
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.utils import infer_device
device = infer_device()
class NaiveAddRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Naive (unfused) residual-add followed by RMS normalization, used as the reference implementation.
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states, residual):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
residual = residual.to(torch.float32)
hidden_states = hidden_states + residual
residual = hidden_states
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype), residual.to(input_dtype)
class AddLigerRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
AddLigerRMSNorm is equivalent to the NaiveAddRMSNorm class above, but uses the LigerRMSNorm kernel for the normalization step (the residual add stays unfused).
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.rms_norm = LigerRMSNorm(hidden_size, eps, in_place=False)
def forward(self, hidden_states, residual):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
residual = residual.to(torch.float32)
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.rms_norm(hidden_states)
return self.weight * hidden_states.to(input_dtype), residual.to(input_dtype)
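# The benchmark below compares three providers: the fully fused kernel
# (LigerFusedAddRMSNorm, residual add + norm in one pass), the naive eager
# version (NaiveAddRMSNorm), and the half-fused variant (AddLigerRMSNorm,
# eager add followed by the LigerRMSNorm kernel).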
def bench_speed_fused_residual_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
N = input.x
provider = input.kernel_provider
mode = input.kernel_operation_mode
extra_benchmark_config = input.extra_benchmark_config
M = extra_benchmark_config["M"]
eps = extra_benchmark_config["eps"]
dtype = extra_benchmark_config["dtype"]
x_shape = (M, N)
# Fused Add RMS Norm
fused_add_rms_norm = LigerFusedAddRMSNorm(hidden_size=N, eps=eps).to(device)
# Naive implementation
naive_rms_norm = NaiveAddRMSNorm(hidden_size=N, eps=eps).to(device)
# LigerRMSNorm without fused residual addition
liger_rms_norm = AddLigerRMSNorm(hidden_size=N, eps=eps).to(device)
x = torch.randn(x_shape, dtype=dtype, device=device)
r = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
ds = torch.randn_like(r)
x.requires_grad_(True)
r.requires_grad_(True)
# utility functions
def y_fwd():
if provider == "liger_fused_add_rms_norm":
return fused_add_rms_norm(x, r)
if provider == "huggingface":
return naive_rms_norm(x, r)
if provider == "liger_rms_norm":
return liger_rms_norm(x, r)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
y_fwd,
grad_to_none=[x, r],
rep=500,
quantiles=QUANTILES,
)
elif mode == "backward":
y, s = y_fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: torch.autograd.backward((y, s), (dy, ds), retain_graph=True),
grad_to_none=[x, r],
rep=500,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y, s = y_fwd()
torch.autograd.backward((y, s), (dy, ds))
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
grad_to_none=[x, r],
rep=500,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_fused_residual_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
N = input.x
provider = input.kernel_provider
extra_benchmark_config = input.extra_benchmark_config
M = extra_benchmark_config["M"]
eps = extra_benchmark_config["eps"]
dtype = extra_benchmark_config["dtype"]
x_shape = (M, N)
fused_add_rms_norm = LigerFusedAddRMSNorm(hidden_size=N, eps=eps).to(device)
naive_rms_norm = NaiveAddRMSNorm(hidden_size=N, eps=eps).to(device)
liger_rms_norm = AddLigerRMSNorm(hidden_size=N, eps=eps).to(device)
x = torch.randn(x_shape, dtype=dtype, device=device)
r = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
ds = torch.randn_like(r)
x.requires_grad_(True)
r.requires_grad_(True)
# utility functions
def y_fwd():
if provider == "liger_fused_add_rms_norm":
return fused_add_rms_norm(x, r)
if provider == "huggingface":
return naive_rms_norm(x, r)
if provider == "liger_rms_norm":
return liger_rms_norm(x, r)
def full():
y, s = y_fwd()
torch.autograd.backward((y, s), (dy, ds))
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "fused_add_rms_norm",
"x_name": "H",
"x_label": "hidden size",
"x_values": [2**i for i in range(10, 16)],
"kernel_providers": ["liger_fused_add_rms_norm", "huggingface", "liger_rms_norm"],
"extra_benchmark_configs": [{"M": 2048, "dtype": torch.float32, "eps": 1e-6}],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_fused_residual_rms_norm,
kernel_operation_modes=["forward", "full", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_fused_residual_rms_norm,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
from liger_kernel.utils import infer_device
device = infer_device()
class TorchLMHeadCE(torch.nn.Module):
"""Ground truth implementation of the linear fused with torch based cross entropy loss.
:param H: hidden size
:param V: vocab size
:param ignore_index: index to ignore
:param reduction: reduction method
"""
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
super().__init__()
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction="mean")
def forward(self, x, y):
logits = self.lin(x)
return self.ce_loss(logits, y)
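# Shape sketch (illustrative, not part of the benchmark harness): the benchmark
# feeds pre-flattened activations x of shape (B*T, H) and integer targets of
# shape (B*T,), so the loss is a scalar mean over all tokens.
def _check_torch_lm_head_ce():
    m = TorchLMHeadCE(H=16, V=32, dtype=torch.float32)
    x = torch.randn(4, 16)
    y = torch.randint(0, 32, (4,))
    loss = m(x, y)  # scalar mean cross-entropy over the 4 tokens
    loss.backward()  # populates m.lin.weight.grad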
class LigerLMHeadCE(torch.nn.Module):
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100, accum_dtype=None):
super().__init__()
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
self.ce_loss = LigerFusedLinearCrossEntropyLoss(
ignore_index=ignore_index, reduction="mean", accum_dtype=accum_dtype
)
def forward(self, x, y):
return self.ce_loss(self.lin.weight, x, y)
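# Note the different call signature: LigerFusedLinearCrossEntropyLoss takes the
# LM-head weight together with the input and target, so the full (B*T, V)
# logits tensor never needs to be materialized. That is the source of the
# memory savings measured by the benchmark below.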
#############################################################################
# Test the memory consumption of the fused linear cross-entropy loss
#############################################################################
def bench_memory_fused_linear_cross_entropy(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
lm_head_ce = None
if provider == "liger":
lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
elif provider == "liger-fp32-accum":
lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
else:
lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
_input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)
def fwd():
return lm_head_ce(_input, target)
def full():
y = fwd()
y.backward()
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
#############################################################################
# Test the speed of the fused linear cross-entropy loss
#############################################################################
def bench_speed_fused_linear_cross_entropy(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
mode = input.kernel_operation_mode
lm_head_ce = None
if provider == "liger":
lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
elif provider == "liger-fp32-accum":
lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
else:
lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
_input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)
def fwd():
return lm_head_ce(_input, target)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "no-grad-forward":
with torch.no_grad():
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[_input],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd()
y.backward()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "fused_linear_cross_entropy",
"x_name": "BT",
"x_label": "B x T",
"x_values": [2**i for i in range(12, 16)],
"kernel_providers": ["liger", "liger-fp32-accum", "huggingface"],
"extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_fused_linear_cross_entropy,
kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_fused_linear_cross_entropy,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD
from liger_kernel.utils import infer_device
device = infer_device()
class TorchJSD(torch.nn.Module):
def __init__(
self,
beta: float = 0.5,
ignore_index: int = -100,
dtype: torch.dtype = torch.float,
):
        super().__init__()
self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True)
self.beta = beta
self.ignore_index = ignore_index
self.dtype = dtype
def forward(
self,
log_q: torch.Tensor, # input
log_p: torch.Tensor, # target
label=None,
):
log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (1 - self.beta) * self.kl(
torch.log(m), log_q
).sum(dim=-1)
if label is not None:
loss = torch.where(label != self.ignore_index, loss, 0.0)
n_non_ignore = (label != self.ignore_index).sum().item()
if n_non_ignore == 0:
loss = 0.0
else:
loss = (loss / n_non_ignore).sum()
else:
loss = (loss / log_q.shape[0]).sum()
return loss.to(self.dtype)
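# TorchJSD computes the generalized (beta-skewed) Jensen-Shannon divergence:
#   m = beta * p + (1 - beta) * q
#   JSD_beta(p || q) = beta * KL(p || m) + (1 - beta) * KL(q || m)
# which reduces to the standard JSD at beta = 0.5. KLDivLoss(log_target=True)
# expects log-probabilities for both arguments, hence the torch.log(m) calls.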
class TorchLMHeadJSD(torch.nn.Module):
"""Ground truth implementation of the linear fused with torch based jsd loss.
:param H: hidden size
:param V: vocab size
:param temperature: softmax temperature
:param beta: jsd beta
"""
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
device: torch.device,
beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
):
super().__init__()
self.student_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
self.jsd = TorchJSD(beta=beta, ignore_index=ignore_index, dtype=dtype)
self.temperature = temperature
def forward(self, student_input, teacher_input, label=None):
student_logits = self.student_lin(student_input)
teacher_logits = self.teacher_lin(teacher_input)
student_prob = torch.log_softmax(student_logits / self.temperature, dim=-1)
teacher_prob = torch.log_softmax(teacher_logits / self.temperature, dim=-1)
return self.jsd(student_prob, teacher_prob, label)
class LigerLMHeadJSD(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
device: torch.device,
beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
):
super().__init__()
self.student_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
self.fused_jsd = LigerFusedLinearJSD(jsd_beta=beta, ignore_index=ignore_index, temperature=temperature)
def forward(self, student_input, teacher_input, label=None):
return self.fused_jsd(
student_input,
self.student_lin.weight,
teacher_input,
self.teacher_lin.weight,
label,
)
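# As with the fused cross-entropy benchmark, the fused path takes the student
# and teacher projection weights directly, letting the kernel fuse the two
# matmuls with the JSD loss instead of materializing both (B*T, V) logit
# tensors.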
#############################################################################
# Test the memory consumption of the fused linear JSD
#############################################################################
def bench_memory_fused_linear_jsd(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)
liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)
    # initialize the student/teacher projections of both implementations with the same weights
torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand(
V, H, device=device, dtype=dtype
)
torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand(
V, H, device=device, dtype=dtype
)
student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device)
teacher_input = torch.rand(BT, H, dtype=dtype, device=device)
def fwd():
if provider == "liger":
return liger_lm_head_jsd(student_input, teacher_input)
elif provider == "torch":
return torch_lm_head_jsd(student_input, teacher_input)
def full():
y = fwd()
y.backward()
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
#############################################################################
# Test the speed of the fused linear JSD
#############################################################################
def bench_speed_fused_linear_jsd(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
mode = input.kernel_operation_mode
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)
liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)
    # initialize the student/teacher projections of both implementations with the same weights
torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand(
V, H, device=device, dtype=dtype
)
torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand(
V, H, device=device, dtype=dtype
)
student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device)
teacher_input = torch.rand(BT, H, dtype=dtype, device=device)
def fwd():
if provider == "liger":
return liger_lm_head_jsd(student_input, teacher_input)
elif provider == "torch":
return torch_lm_head_jsd(student_input, teacher_input)
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[
student_input,
torch_lm_head_jsd.student_lin.weight,
torch_lm_head_jsd.teacher_lin.weight,
],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd()
y.backward()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "fused_linear_jsd",
"x_name": "BT",
"x_label": "B x T",
"x_values": [2**i for i in range(10, 14)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}],
"overwrite": args.overwrite,
}
run_benchmarks(
bench_test_fn=bench_speed_fused_linear_jsd,
kernel_operation_modes=["forward", "backward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_fused_linear_jsd,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
import math
from typing import Optional
import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from liger_kernel.transformers.fused_neighborhood_attention import LigerFusedNeighborhoodAttention
from liger_kernel.utils import infer_device
device = infer_device()
class TorchNeighborhoodAttention(torch.nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
kernel_size: int = 7,
dilation: int = 1,
bias: bool = True,
dropout: float = 0.0,
        scale: Optional[float] = None,
):
super().__init__()
if hidden_size % num_heads != 0:
raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.kernel_size = kernel_size
self.dilation = dilation
self.scale = scale if scale is not None else 1.0 / math.sqrt(self.head_dim)
self.q_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias)
self.k_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias)
self.v_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias)
self.out_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias)
if dropout > 0.0:
self.dropout = torch.nn.Dropout(dropout)
else:
self.dropout = None
def _create_neighborhood_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
mask = torch.zeros(seq_len, seq_len, device=device, dtype=torch.bool)
half_kernel = self.kernel_size // 2
for i in range(seq_len):
start = max(0, i - half_kernel * self.dilation)
end = min(seq_len, i + half_kernel * self.dilation + 1)
for j in range(start, end):
if self.dilation == 1 or (j - i) % self.dilation == 0:
mask[i, j] = True
return mask
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, hidden_size = hidden_states.shape
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
mask = self._create_neighborhood_mask(seq_len, hidden_states.device)
scores = scores.masked_fill(~mask, float("-inf"))
attn_weights = torch.softmax(scores, dim=-1)
if self.dropout is not None:
attn_weights = self.dropout(attn_weights)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
output = self.out_proj(attn_output)
return output
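# Worked example for _create_neighborhood_mask (illustrative): with seq_len=5,
# kernel_size=3, dilation=1, each token i attends to {i-1, i, i+1} clipped to
# the sequence, giving the tridiagonal boolean mask
#   [[1, 1, 0, 0, 0],
#    [1, 1, 1, 0, 0],
#    [0, 1, 1, 1, 0],
#    [0, 0, 1, 1, 1],
#    [0, 0, 0, 1, 1]]
# With dilation=2 the window becomes {i-2, i, i+2}, i.e. every other position.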
def bench_speed_fused_neighborhood_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
seq_len = input.x
provider = input.kernel_provider
mode = input.kernel_operation_mode
extra_benchmark_config = input.extra_benchmark_config
batch_size = extra_benchmark_config["batch_size"]
hidden_size = extra_benchmark_config["hidden_size"]
num_heads = extra_benchmark_config["num_heads"]
kernel_size = extra_benchmark_config["kernel_size"]
dilation = extra_benchmark_config["dilation"]
bias = extra_benchmark_config["bias"]
dtype = extra_benchmark_config["dtype"]
x_shape = (batch_size, seq_len, hidden_size)
liger_attn = (
LigerFusedNeighborhoodAttention(
hidden_size=hidden_size,
num_heads=num_heads,
kernel_size=kernel_size,
dilation=dilation,
bias=bias,
dropout=0.0,
)
.to(device)
.to(dtype)
)
torch_attn = (
TorchNeighborhoodAttention(
hidden_size=hidden_size,
num_heads=num_heads,
kernel_size=kernel_size,
dilation=dilation,
bias=bias,
dropout=0.0,
)
.to(device)
.to(dtype)
)
with torch.no_grad():
torch_attn.q_proj.weight.copy_(liger_attn.q_proj.weight)
torch_attn.k_proj.weight.copy_(liger_attn.k_proj.weight)
torch_attn.v_proj.weight.copy_(liger_attn.v_proj.weight)
torch_attn.out_proj.weight.copy_(liger_attn.out_proj.weight)
if bias:
torch_attn.q_proj.bias.copy_(liger_attn.q_proj.bias)
torch_attn.k_proj.bias.copy_(liger_attn.k_proj.bias)
torch_attn.v_proj.bias.copy_(liger_attn.v_proj.bias)
torch_attn.out_proj.bias.copy_(liger_attn.out_proj.bias)
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)
def fwd():
if provider == "liger":
return liger_attn(x)
elif provider == "torch":
return torch_attn(x)
print(f"Starting Warmup for input size: {x_shape}")
_ = fwd()
if mode in ("backward", "full"):
y = _
y.backward(dy, retain_graph=True)
print("Done Warmup")
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, grad_to_none=[x], rep=100, quantiles=QUANTILES)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(dy, retain_graph=True),
grad_to_none=[x],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":
def full():
y = fwd()
y.backward(dy, retain_graph=True)
ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)
def bench_memory_fused_neighborhood_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
seq_len = input.x
provider = input.kernel_provider
extra_benchmark_config = input.extra_benchmark_config
batch_size = extra_benchmark_config["batch_size"]
hidden_size = extra_benchmark_config["hidden_size"]
num_heads = extra_benchmark_config["num_heads"]
kernel_size = extra_benchmark_config["kernel_size"]
dilation = extra_benchmark_config["dilation"]
bias = extra_benchmark_config["bias"]
dtype = extra_benchmark_config["dtype"]
x_shape = (batch_size, seq_len, hidden_size)
liger_attn = (
LigerFusedNeighborhoodAttention(
hidden_size=hidden_size,
num_heads=num_heads,
kernel_size=kernel_size,
dilation=dilation,
bias=bias,
dropout=0.0,
)
.to(device)
.to(dtype)
)
torch_attn = (
TorchNeighborhoodAttention(
hidden_size=hidden_size,
num_heads=num_heads,
kernel_size=kernel_size,
dilation=dilation,
bias=bias,
dropout=0.0,
)
.to(device)
.to(dtype)
)
with torch.no_grad():
torch_attn.q_proj.weight.copy_(liger_attn.q_proj.weight)
torch_attn.k_proj.weight.copy_(liger_attn.k_proj.weight)
torch_attn.v_proj.weight.copy_(liger_attn.v_proj.weight)
torch_attn.out_proj.weight.copy_(liger_attn.out_proj.weight)
if bias:
torch_attn.q_proj.bias.copy_(liger_attn.q_proj.bias)
torch_attn.k_proj.bias.copy_(liger_attn.k_proj.bias)
torch_attn.v_proj.bias.copy_(liger_attn.v_proj.bias)
torch_attn.out_proj.bias.copy_(liger_attn.out_proj.bias)
x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)
def fwd():
if provider == "liger":
return liger_attn(x)
elif provider == "torch":
return torch_attn(x)
def full():
y = fwd()
y.backward(dy, retain_graph=True)
mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)
if __name__ == "__main__":
args = parse_benchmark_script_args()
common_configs = {
"kernel_name": "fused_neighborhood_attention",
"x_name": "seq_len",
"x_label": "sequence length",
"x_values": [2**i for i in range(6, 13)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [
{
"batch_size": 2,
"hidden_size": 512,
"num_heads": 8,
"kernel_size": 7,
"dilation": 1,
"bias": True,
"dtype": torch.float32,
},
{
"batch_size": 4,
"hidden_size": 768,
"num_heads": 12,
"kernel_size": 7,
"dilation": 1,
"bias": True,
"dtype": torch.float32,
},
{
"batch_size": 2,
"hidden_size": 1024,
"num_heads": 16,
"kernel_size": 9,
"dilation": 1,
"bias": True,
"dtype": torch.float32,
},
{
"batch_size": 2,
"hidden_size": 512,
"num_heads": 8,
"kernel_size": 7,
"dilation": 2,
"bias": True,
"dtype": torch.float32,
},
{
"batch_size": 2,
"hidden_size": 512,
"num_heads": 8,
"kernel_size": 7,
"dilation": 1,
"bias": True,
"dtype": torch.bfloat16,
},
{
"batch_size": 4,
"hidden_size": 768,
"num_heads": 12,
"kernel_size": 7,
"dilation": 1,
"bias": True,
"dtype": torch.bfloat16,
},
{
"batch_size": 2,
"hidden_size": 1024,
"num_heads": 16,
"kernel_size": 9,
"dilation": 1,
"bias": True,
"dtype": torch.bfloat16,
},
{
"batch_size": 2,
"hidden_size": 512,
"num_heads": 8,
"kernel_size": 7,
"dilation": 2,
"bias": True,
"dtype": torch.bfloat16,
},
],
}
run_benchmarks(
bench_test_fn=bench_speed_fused_neighborhood_attention,
kernel_operation_modes=["forward", "full", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_fused_neighborhood_attention,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)