diff --git a/LICENSE b/LICENSE
new file mode 100755
index 0000000000000000000000000000000000000000..d2fcc2b1c4384d0bcd1424b7f83db8e48fa753f6
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,23 @@
+BSD 2-CLAUSE LICENSE
+Copyright 2024 LinkedIn Corporation
+All Rights Reserved.
+Redistribution and use in source and binary forms, with or
+without modification, are permitted provided that the following
+conditions are met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following
+disclaimer in the documentation and/or other materials provided
+with the distribution.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/Makefile b/Makefile
new file mode 100755
index 0000000000000000000000000000000000000000..904d3c1492b7a20bb2a0e993404ed644c770b213
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,73 @@
+.PHONY: test checkstyle test-convergence all serve build clean
+
+
+all: checkstyle test test-convergence
+
+# Command to run pytest for correctness tests
+test:
+	python -m pytest --disable-warnings \
+	--cov=src/liger_kernel \
+	--cov-report=term-missing \
+	--ignore=test/convergence \
+	test/
+
+# Command to run coverage report
+coverage:
+	coverage report -m
+
+# Command to run ruff for linting and formatting code
+checkstyle:
+	ruff check --output-format=concise .; ruff_check_status=$$?; \
+	ruff format --check --diff .; ruff_format_status=$$?; \
+	ruff check . --fix; \
+	ruff format .; \
+	if [ $$ruff_check_status -ne 0 ] || [ $$ruff_format_status -ne 0 ]; then \
+		exit 1; \
+	fi
+
+# Command to run pytest for convergence tests
+# We have to explicitly set HF_DATASETS_OFFLINE=1, or dataset will silently try to send metrics and timeout (80s) https://github.com/huggingface/datasets/blob/37a603679f451826cfafd8aae00738b01dcb9d58/src/datasets/load.py#L286
+test-convergence:
+	HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models.py
+	HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_multimodal.py
+	HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_with_logits.py
+
+	HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models.py
+	HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_multimodal.py
+	HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_with_logits.py
+
+# Command to run all benchmark scripts and update benchmarking data file
+# By default this doesn't overwrite existing data for the same benchmark experiment
+# run with `make run-benchmarks OVERWRITE=1` to overwrite existing benchmark data
+BENCHMARK_DIR = benchmark/scripts
+BENCHMARK_SCRIPTS = $(wildcard $(BENCHMARK_DIR)/benchmark_*.py)
+OVERWRITE ?= 0
+
+run-benchmarks:
+	@for script in $(BENCHMARK_SCRIPTS); do \
+		echo "Running benchmark: $$script"; \
+		if [ $(OVERWRITE) -eq 1 ]; then \
+			python $$script --overwrite; \
+		else \
+			python $$script; \
+		fi; \
+	done
+
+# MkDocs Configuration
+MKDOCS = mkdocs
+CONFIG_FILE = mkdocs.yml
+SITE_DIR = site
+
+# MkDocs targets
+
+# Serve the documentation
+serve:
+	$(MKDOCS) serve -f $(CONFIG_FILE)
+
+# Build the documentation into the specified site directory
+build:
+	$(MKDOCS) build -f $(CONFIG_FILE) --site-dir $(SITE_DIR)
+
+# Clean the output directory
+clean:
+	rm -rf $(SITE_DIR)/
diff --git a/NOTICE b/NOTICE
new file mode 100755
index 0000000000000000000000000000000000000000..ea2881754f5b3e0eb9926dd9dc6c9d772f962911
--- /dev/null
+++ b/NOTICE
@@ -0,0 +1,58 @@
+Copyright 2024 LinkedIn Corporation
+All Rights Reserved.
+
+Licensed under the BSD 2-Clause License (the "License"). See License in the project root for license information.
+
+This product includes software developed by LinkedIn Corporation.
+
+This product contains code derived from the following open source projects:
+
+1. Unsloth
+   Copyright (c) 2023 Unsloth AI
+   Licensed under the Apache License, Version 2.0
+   Source: https://github.com/unslothai/unsloth
+
+   The `calculate_settings` function to determine block size and warp is reused for Norm and MLP operations.
+   Modifications and additions were made to the RMS Norm implementation.
+
+2. Triton
+   Copyright (c) 2023 OpenAI
+   Licensed under the MIT License
+   Source: https://github.com/openai/triton
+
+   Modifications were made based on Triton tutorials for the RMS Norm implementation.
+
+3. Efficient Cross Entropy
+   Copyright (c) 2023 Mohamed Malek
+   Licensed under the MIT License
+   Source: https://github.com/mgmalek/efficient_cross_entropy
+
+   The idea of gradient-in-forward and chunking was used in the Linear Cross Entropy implementation.
+
+4. Flash Attention
+   Copyright (c) 2023 Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
+   Licensed under the BSD 3-Clause License
+   Source: https://github.com/Dao-AILab/flash-attention
+
+   Optimization ideas such as tiling and recomputation were inspired by this work.
+
+5. AutoAWQ
+   Copyright (c) 2023 Casper Hansen
+   Licensed under the MIT License
+   Source: https://github.com/casper-hansen/AutoAWQ
+
+   The design of the automodel was referenced from this project.
+
+6. llm.c
+   Copyright (c) 2023 Andrej Karpathy
+   Licensed under the MIT License
+   Source: https://github.com/karpathy/llm.c
+
+   The design of end-to-end testing was referenced from this project.
+
+7. Tiny Shakespeare Dataset
+   Source: https://huggingface.co/datasets/karpathy/tiny_shakespeare
+
+   This dataset is used to conduct convergence tests on mini models.
+
+For full license texts, please refer to the respective project repositories.
diff --git a/benchmark/BENCHMARK_GUIDELINES.md b/benchmark/BENCHMARK_GUIDELINES.md
new file mode 100755
index 0000000000000000000000000000000000000000..907223430151540d36acf7ac73509a1252f2ca65
--- /dev/null
+++ b/benchmark/BENCHMARK_GUIDELINES.md
@@ -0,0 +1,101 @@
+# Guideline for Adding Benchmark Scripts
+
+This document describes how to add new benchmark scripts to Liger-Kernel in line with the shared framework.
+
+## 1. Where and how to add a script
+
+- **Location**: `benchmark/scripts/`
+- **Naming**: `benchmark_<kernel_name>.py` (e.g. `benchmark_geglu.py`, `benchmark_swiglu.py`)
+
+## 2. Use shared infrastructure
+
+Do **not** hardcode batch size, sequence length, or model dimensions. Use:
+
+| Need | Use |
+|------|-----|
+| Model dimensions (hidden_size, vocab_size, etc.) | `benchmark_model_configs.py`: `ModelConfig`, `get_benchmark_model_config()` |
+| Safe sweep config (seq_len or hidden_size) | `compute_seq_len_sweep_config()` (returns `SeqLenSweepConfig`) or `compute_hidden_size_sweep_config()` (returns `HiddenSizeSweepConfig`), with optional `estimate_kernel_peak_memory()` |
+| Speed / memory measurement | `utils.py`: `run_speed_benchmark()`, `run_memory_benchmark()` |
+| CLI (overwrite, model choice) | `utils.py`: `parse_benchmark_script_args()` (includes `--model`) |
+| Running the grid and writing CSV | `utils.py`: `run_benchmarks()` |
+
+## 3. Script structure (three parts)
+
+### 3.1 Setup factory
+
+Define a single **setup function** that builds inputs and the layer (or callable) from `SingleBenchmarkRunInput`, so both speed and memory benchmarks reuse the same setup.
+
+- **Signature**: `_setup_<kernel>(input: SingleBenchmarkRunInput) -> (tensors, layer_or_fn)`
+- **Input**: `input.x` is the varying dimension (e.g. sequence length); `input.extra_benchmark_config` holds `bsz`, `hidden_size`, `dtype`, etc.; `input.kernel_provider` identifies the implementation variant (e.g. `"liger"`, `"huggingface"`, `"torch"`; values are kernel-specific).
+- **Return**: Whatever the benchmark helpers need (e.g. `(x, layer)` for a single-tensor forward like GEGLU).
+
+Example (conceptually):
+
+```python
+def _setup_geglu(input: SingleBenchmarkRunInput):
+    cfg = input.extra_benchmark_config
+    # Build config, create x tensor, instantiate LigerGEGLUMLP or LlamaMLP by provider
+    return x, layer
+```
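+
+For orientation, here is a fuller sketch of the same factory. It is illustrative only: the config keys (`bsz`, `hidden_size`, `intermediate_size`), the import for `SingleBenchmarkRunInput`, and the probe-friendly defaults are assumptions in this sketch, not a prescribed implementation.
+
+```python
+import torch
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaMLP
+
+from liger_kernel.transformers.geglu import LigerGEGLUMLP
+from utils import SingleBenchmarkRunInput  # assumed to live in the shared utils.py
+
+
+def _setup_geglu(input: SingleBenchmarkRunInput):
+    cfg = input.extra_benchmark_config
+    seq_len = input.x  # the varying dimension in a seq_len sweep
+
+    # Hypothetical config keys; use whatever your __main__ put into
+    # extra_benchmark_configs.
+    llama_config = LlamaConfig(
+        hidden_size=cfg["hidden_size"],
+        intermediate_size=cfg["intermediate_size"],
+        hidden_act="gelu_pytorch_tanh",  # GEGLU-style activation
+    )
+    x = torch.randn(
+        cfg["bsz"], seq_len, cfg["hidden_size"],
+        dtype=cfg["dtype"], device="cuda", requires_grad=True,
+    )
+    # Pick the implementation variant by provider.
+    if input.kernel_provider == "liger":
+        layer = LigerGEGLUMLP(llama_config).to("cuda").to(cfg["dtype"])
+    else:
+        layer = LlamaMLP(llama_config).to("cuda").to(cfg["dtype"])
+    return x, layer
+```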
+
+### 3.2 Speed and memory benchmark functions
+
+Each takes `SingleBenchmarkRunInput` and returns `SingleBenchmarkRunOutput` by calling the shared helpers.
+
+- **Speed**: `run_speed_benchmark(fwd_fn, mode, input_tensors, rep=...)`
+- **Memory**: `run_memory_benchmark(fwd_fn, mode)`
+- **Modes**: Use `["full", "forward", "backward"]` for both speed and memory for consistency.
+
+Example:
+
+```python
+def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    x, layer = _setup_geglu(input)
+    return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x])
+
+def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    x, layer = _setup_geglu(input)
+    return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode)
+```
+
+For **scalar output** (e.g. loss) or **multiple outputs** (e.g. RoPE), use the appropriate helpers from `utils.py` if available (e.g. loss or multi-output variants), or implement custom measurement and still use the same setup factory and `run_benchmarks()`.
+
+### 3.3 `__main__`: model config, shape computation, run
+
+1. Parse args: `args = parse_benchmark_script_args()` and resolve `model = get_benchmark_model_config(args.model)`.
+2. (Recommended) Measure peak memory with a small probe using the **highest-memory baseline** implementation (e.g. `"huggingface"` or `"torch"`):
+   - Define a `_probe()` function that creates tensors/layers, runs a forward pass, and returns the output tensor. `_probe()` owns setup; `estimate_kernel_peak_memory` handles memory-stat reset before the call, runs `.backward()`, and performs cleanup (gc + cache clear) afterward.
+   - Call `peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)`.
+3. Compute sweep config (device memory is obtained internally by both helpers):
+   - **Sequence-length sweep** (e.g. GEGLU, SwiGLU): convert peak bytes to per-token (`kernel_bpt = peak_bytes // probe_seq_len`), then `config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)`. The returned `SeqLenSweepConfig` has `batch_size` and `seq_len`.
+   - **Hidden-size sweep** (e.g. DyT): pass total peak bytes directly: `config = compute_hidden_size_sweep_config(model, kernel_peak_bytes=peak_bytes, bt=BT)`. The returned `HiddenSizeSweepConfig` has `bt` and `max_hidden_size`.
+4. Build `x_values` from `config.seq_len` (seq_len sweep) or `config.max_hidden_size` (hidden_size sweep).
+5. Build `extra_benchmark_configs` from `model` and config:
+   - Seq_len sweep: e.g. `bsz=config.batch_size`, `hidden_size=model.hidden_size`, `dtype=model.dtype`.
+   - Hidden_size sweep: e.g. `BT=config.bt`, `dtype=model.dtype`.
+6. Call `run_benchmarks(..., kernel_operation_modes=["full", "forward", "backward"], ...)` for both speed and memory, as sketched below.
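+
+Putting the steps together, a seq_len-sweep `__main__` might look like the following. This is a hedged sketch: the `SingleBenchmarkRunInput` constructor fields, the probe batch size and length, the `intermediate_size` attribute on `ModelConfig`, and the exact `run_benchmarks()` keyword set are assumptions; follow the real signatures in `utils.py`.
+
+```python
+if __name__ == "__main__":
+    args = parse_benchmark_script_args()  # step 1: --model, --overwrite
+    model = get_benchmark_model_config(args.model)
+
+    PROBE_SEQ_LEN = 1024  # hypothetical probe length for step 2
+
+    def _probe():
+        # Probe with the highest-memory baseline provider; estimate_kernel_peak_memory
+        # resets memory stats before the call, runs .backward(), and cleans up after.
+        probe_input = SingleBenchmarkRunInput(  # constructor fields assumed
+            x=PROBE_SEQ_LEN,
+            kernel_provider="huggingface",
+            kernel_operation_mode="full",
+            extra_benchmark_config={
+                "bsz": 2,  # small probe batch, assumed
+                "hidden_size": model.hidden_size,
+                "intermediate_size": model.intermediate_size,  # assumed ModelConfig field
+                "dtype": model.dtype,
+            },
+        )
+        x, layer = _setup_geglu(probe_input)
+        return layer(x)
+
+    peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
+
+    # step 3: per-token cost -> safe (batch_size, seq_len) for this device
+    kernel_bpt = peak_bytes // PROBE_SEQ_LEN
+    config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+
+    # steps 4-5: sweep powers of two up to the safe seq_len
+    x_values = [2**i for i in range(10, 16) if 2**i <= config.seq_len]
+    extra_benchmark_configs = [{
+        "bsz": config.batch_size,
+        "hidden_size": model.hidden_size,
+        "intermediate_size": model.intermediate_size,  # assumed
+        "dtype": model.dtype,
+    }]
+
+    # step 6: run speed and memory over the same grid (kwargs indicative)
+    for bench_fn, metric_name in [(bench_speed_geglu, "speed"), (bench_memory_geglu, "memory")]:
+        run_benchmarks(
+            bench_test_fn=bench_fn,
+            kernel_name="geglu",
+            metric_name=metric_name,
+            x_name="T",
+            x_values=x_values,
+            kernel_providers=["liger", "huggingface"],
+            kernel_operation_modes=["full", "forward", "backward"],
+            extra_benchmark_configs=extra_benchmark_configs,
+            overwrite=args.overwrite,
+        )
+```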
+
+## 4. CLI
+
+Scripts should support:
+
+- `--overwrite`: overwrite existing rows in the benchmark CSV.
+- `--model`: model profile name from `MODEL_REGISTRY` (e.g. `llama_2_7b`, `llama_3_8b`). Default when not set is `DEFAULT_MODEL_CONFIG` (e.g. `llama_3_8b`).
+
+These are provided by `parse_benchmark_script_args()` in `utils.py`.
+
+## 5. Reference scripts
+
+- **Element-wise (single tensor in/out, seq_len sweep)**: `benchmark_geglu.py`, `benchmark_swiglu.py` — `compute_seq_len_sweep_config()`.
+- **Element-wise (single tensor in/out, hidden_size sweep)**: `benchmark_dyt.py` — `compute_hidden_size_sweep_config()`.
+
+## 6. Checklist for a new script
+
+- [ ] Script under `benchmark/scripts/` named `benchmark_<kernel_name>.py`.
+- [ ] Single `_setup_<kernel>(SingleBenchmarkRunInput)` used by both speed and memory.
+- [ ] Speed/memory implemented via `run_speed_benchmark` / `run_memory_benchmark` (or the correct variant for loss / multi-output).
+- [ ] `kernel_operation_modes=["full", "forward", "backward"]` for both speed and memory.
+- [ ] No hardcoded batch size or sequence length; use `compute_seq_len_sweep_config()` or `compute_hidden_size_sweep_config()` (and optionally `estimate_kernel_peak_memory()`).
+- [ ] Model dimensions and dtype from `ModelConfig` / `get_benchmark_model_config()` / `args.model`.
+- [ ] CLI via `parse_benchmark_script_args()` (so `--model` and `--overwrite` work).
+- [ ] Results written through `run_benchmarks()` so data goes to the shared CSV.
diff --git a/benchmark/README.md b/benchmark/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..02c883d9215de7dbe38174c46deb1edd2bb01d4f
--- /dev/null
+++ b/benchmark/README.md
@@ -0,0 +1,48 @@
+## Benchmarking Liger Kernels
+
+Follow these steps to benchmark and visualize kernel performance:
+
+1. Create a benchmark script
+   - Add your script under `benchmark/scripts/`
+   - Name it according to the kernel (e.g., `benchmark_<kernel_name>.py`)
+
+2. Run the benchmark
+   - Results will be saved to `benchmark/data/all_benchmark_data.csv`
+
+   Example: Benchmarking KTO Loss
+   ```bash
+   cd benchmark
+   python scripts/benchmark_kto_loss.py
+   ```
+
+3. Visualize results
+   - Use the visualization script with optional modes:
+
+   * To target specific mode(s), pass `--kernel-operation-mode` with one or more values.
+   * If you omit `--kernel-operation-mode`, the script will:
+     - For `speed` metrics: generate plots for all available modes (forward/backward/full).
+     - For `memory` metrics: generate only the `full` plot.
+
+   Examples:
+   1. Specific modes (speed):
+   ```bash
+   python benchmarks_visualizer.py \
+       --kernel-name kto_loss \
+       --metric-name speed \
+       --kernel-operation-mode forward backward
+   ```
+   2. All modes (speed):
+   ```bash
+   python benchmarks_visualizer.py \
+       --kernel-name kto_loss \
+       --metric-name speed
+   ```
+   3. Memory (always full):
+   ```bash
+   python benchmarks_visualizer.py \
+       --kernel-name kto_loss \
+       --metric-name memory
+   ```
+
+4. View results
+   - Generated plots will be saved in `benchmark/visualizations/`
\ No newline at end of file
diff --git a/benchmark/__init__.py b/benchmark/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/benchmark/benchmarks_visualizer.py b/benchmark/benchmarks_visualizer.py
new file mode 100755
index 0000000000000000000000000000000000000000..e33d844eaeba7a945b660fd4619183e3689226e4
--- /dev/null
+++ b/benchmark/benchmarks_visualizer.py
@@ -0,0 +1,299 @@
+import json
+import os
+import sys
+
+from argparse import ArgumentParser, Namespace
+from dataclasses import dataclass
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import seaborn as sns
+
+DATA_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "data/all_benchmark_data.csv"))
+VISUALIZATIONS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "visualizations/"))
+
+
+@dataclass
+class VisualizationsConfig:
+    """
+    Configuration for the visualizations script.
+
+    Args:
+        kernel_name (str): Kernel name to benchmark. (Will run `scripts/benchmark_{kernel_name}.py`)
+        metric_name (str): Metric name to visualize (speed/memory)
+        kernel_operation_mode (str): Kernel operation mode to visualize (forward/backward/full). Defaults to "full"
+        extra_config_filter (str, optional): A string to filter extra_benchmark_config.
+            Can be a substring to match or a 'key=value' pair (e.g., "'H': 4096").
+            Defaults to None, which means the first available config will be used if multiple exist.
+        display (bool): Display the visualization. Defaults to False
+        overwrite (bool): Overwrite an existing visualization. If none exists, this flag has no effect, as plots are always created and saved. Defaults to False
+
+    """
+
+    kernel_name: str
+    metric_name: str
+    kernel_operation_mode: str = "full"
+    extra_config_filter: str | None = None
+    display: bool = False
+    overwrite: bool = False
+
+
+def parse_args() -> Namespace:
+    """Parse command line arguments.
+
+    Returns:
+        Namespace: Parsed command line arguments for the visualizations script.
+    """
+    parser = ArgumentParser()
+    parser.add_argument("--kernel-name", type=str, required=True, help="Kernel name to benchmark")
+    parser.add_argument(
+        "--metric-name",
+        type=str,
+        required=True,
+        help="Metric name to visualize (speed/memory)",
+    )
+    parser.add_argument(
+        "--kernel-operation-mode",
+        type=str,
+        nargs="*",
+        default=None,
+        help="Kernel operation modes to visualize (forward/backward/full). If not provided, generate for all available modes.",
+    )
+    parser.add_argument(
+        "--extra-config-filter",
+        type=str,
+        default=None,
+        help="A string to filter extra_benchmark_config. "
+        "Can be a substring to match or a JSON-like 'key=value' pair (e.g., \"'H': 4096\" or \"H=4096\" for simple cases). "
+        "Defaults to None (first available config if multiple exist).",
+    )
+    parser.add_argument("--display", action="store_true", help="Display the visualization")
+    parser.add_argument(
+        "--overwrite",
+        action="store_true",
+        help="Overwrite an existing visualization; if none exists, this flag has no effect, as plots are always created",
+    )
+
+    args = parser.parse_args()
+    return args
+
+
+def load_data(config: VisualizationsConfig) -> pd.DataFrame:
+    """Loads the benchmark data from the CSV file and filters it based on the configuration.
+
+    Args:
+        config (VisualizationsConfig): Configuration object for the visualizations script.
+
+    Raises:
+        ValueError: If no data is found for the given filters.
+
+    Returns:
+        pd.DataFrame: Filtered benchmark dataframe.
+    """
+    df = pd.read_csv(DATA_PATH)
+    df["extra_benchmark_config"] = df["extra_benchmark_config_str"].apply(json.loads)
+
+    base_filtered_df = df[
+        (df["kernel_name"] == config.kernel_name)
+        & (df["metric_name"] == config.metric_name)
+        & (df["kernel_operation_mode"] == config.kernel_operation_mode)
+    ]
+
+    if base_filtered_df.empty:
+        raise ValueError(
+            f"No data found for kernel_name='{config.kernel_name}', "
+            f"metric_name='{config.metric_name}', "
+            f"kernel_operation_mode='{config.kernel_operation_mode}'."
+        )
+
+    unique_extra_configs_str = base_filtered_df["extra_benchmark_config_str"].unique()
+    selected_extra_config_str = None
+
+    if len(unique_extra_configs_str) == 0:
+        print(
+            "Warning: No extra_benchmark_config found for the initial filters. "
+            "Proceeding with all data from initial filter."
+        )
+        return base_filtered_df
+
+    if config.extra_config_filter:
+        matched_configs = []
+        try:
+            if "=" in config.extra_config_filter:
+                key_filter, value_filter = config.extra_config_filter.split("=", 1)
+                for cfg_str in unique_extra_configs_str:
+                    cfg_json = json.loads(cfg_str)
+                    if str(cfg_json.get(key_filter.strip("'\" "))) == value_filter.strip("'\" "):
+                        matched_configs.append(cfg_str)
+            if not matched_configs:
+                matched_configs = [
+                    cfg_str for cfg_str in unique_extra_configs_str if config.extra_config_filter in cfg_str
+                ]
+        except Exception as e:
+            print(
+                f"Note: Could not parse extra_config_filter '{config.extra_config_filter}' as key=value ({e}), using substring match."
+            )
+            matched_configs = [cfg_str for cfg_str in unique_extra_configs_str if config.extra_config_filter in cfg_str]
+
+        if matched_configs:
+            if len(matched_configs) > 1:
+                print(
+                    f"Warning: Multiple extra_benchmark_configs match filter '{config.extra_config_filter}': {matched_configs}. "
+                    f"Using the first one: {matched_configs[0]}"
+                )
+            selected_extra_config_str = matched_configs[0]
+        else:
+            print(
+                f"Warning: No extra_benchmark_config matches filter '{config.extra_config_filter}'. "
+                f"Available configs for {config.kernel_name} ({config.metric_name}, {config.kernel_operation_mode}): {list(unique_extra_configs_str)}"
+            )
+            if len(unique_extra_configs_str) > 0:
+                selected_extra_config_str = unique_extra_configs_str[0]
+                print(f"Defaulting to the first available extra_benchmark_config: {selected_extra_config_str}")
+            else:
+                raise ValueError("No extra_benchmark_config available to select after failed filter attempt.")
+
+    elif len(unique_extra_configs_str) > 1:
+        selected_extra_config_str = unique_extra_configs_str[0]
+        print(
+            f"Warning: Multiple extra_benchmark_configs found for {config.kernel_name} ({config.metric_name}, {config.kernel_operation_mode})."
+        )
+        print(f"Defaulting to use: {selected_extra_config_str}")
+        print(f"Available configs: {list(unique_extra_configs_str)}")
+        print(
+            "Use the --extra-config-filter argument to select a specific one "
+            "(e.g., --extra-config-filter \"'H': 4096\" or a substring like \"'seq_len': 512\")."
+        )
+    elif len(unique_extra_configs_str) == 1:
+        selected_extra_config_str = unique_extra_configs_str[0]
+        print(f"Using unique extra_benchmark_config: {selected_extra_config_str}")
+
+    if selected_extra_config_str:
+        final_filtered_df = base_filtered_df[
+            base_filtered_df["extra_benchmark_config_str"] == selected_extra_config_str
+        ]
+    else:
+        print("Warning: Could not select an extra_benchmark_config. Using data from initial filter if any.")
+        final_filtered_df = base_filtered_df
+
+    if final_filtered_df.empty:
+        raise ValueError(
+            f"No data found after attempting to filter by extra_benchmark_config. "
+            f"Selected/Defaulted extra_config_str: {selected_extra_config_str}"
+            if selected_extra_config_str
+            else "No specific extra_config was selected."
+        )
+
+    print(
+        f"Plotting data for extra_benchmark_config: {json.loads(selected_extra_config_str if selected_extra_config_str else '{}')}"
+    )
+    return final_filtered_df
+
+
+def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
+    """Plots the benchmark data, saving the result if needed.
+
+    Args:
+        df (pd.DataFrame): Filtered benchmark dataframe.
+        config (VisualizationsConfig): Configuration object for the visualizations script.
+ """ + for col in ["y_value_20", "y_value_50", "y_value_80"]: + if col in df.columns: + df[col] = pd.to_numeric(df[col], errors="coerce") + + xlabel = df["x_label"].iloc[0] + ylabel = f"{config.metric_name} ({df['metric_unit'].iloc[0]})" + # Sort by "kernel_provider" to ensure consistent color assignment + df = df.sort_values(by="kernel_provider") + + plt.figure(figsize=(10, 6)) + sns.set(style="whitegrid") + try: + ax = sns.lineplot( + data=df, + x="x_value", + y="y_value_50", + hue="kernel_provider", + marker="o", + palette="tab10", + errorbar=("ci", None), + ) + except Exception: + ax = sns.lineplot( + data=df, + x="x_value", + y="y_value_50", + hue="kernel_provider", + marker="o", + palette="tab10", + errorbar=None, + ) + + # Seaborn can't plot pre-computed error bars, so we need to do it manually + lines = ax.get_lines() + colors = [line.get_color() for line in lines] + + for (_, group_data), color in zip(df.groupby("kernel_provider"), colors): + # for i, row in group_data.iterrows(): + y_error_lower = group_data["y_value_50"] - group_data["y_value_20"] + y_error_upper = group_data["y_value_80"] - group_data["y_value_50"] + y_error = [y_error_lower, y_error_upper] + + plt.errorbar( + group_data["x_value"], + group_data["y_value_50"], + yerr=y_error, + fmt="o", + color=color, + capsize=5, + ) + plt.legend(title="Kernel Provider") + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.tight_layout() + + out_path = os.path.join( + VISUALIZATIONS_PATH, + f"{config.kernel_name}_{config.metric_name}_{config.kernel_operation_mode}.png", + ) + + if config.display: + plt.show() + if config.overwrite or not os.path.exists( + out_path + ): # Save the plot if it doesn't exist or if we want to overwrite it + os.makedirs(VISUALIZATIONS_PATH, exist_ok=True) + plt.savefig(out_path) + plt.close() + + +def main(): + args = parse_args() + all_df = pd.read_csv(DATA_PATH) + all_df["extra_benchmark_config"] = all_df["extra_benchmark_config_str"].apply(json.loads) + + if args.metric_name == "memory": + modes = ["full"] + elif args.kernel_operation_mode: + modes = args.kernel_operation_mode + else: + filtered = all_df[(all_df["kernel_name"] == args.kernel_name) & (all_df["metric_name"] == args.metric_name)] + modes = filtered["kernel_operation_mode"].unique().tolist() + if not modes: + print(f"No data found for kernel '{args.kernel_name}' and metric '{args.metric_name}'.", file=sys.stderr) + sys.exit(1) + + for mode in modes: + config = VisualizationsConfig( + kernel_name=args.kernel_name, + metric_name=args.metric_name, + kernel_operation_mode=mode, + display=args.display, + overwrite=args.overwrite, + ) + df = load_data(config) + plot_data(df, config) + + +if __name__ == "__main__": + main() diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv new file mode 100755 index 0000000000000000000000000000000000000000..f63286a16a0e577b7a51bb672706cb713172f024 --- /dev/null +++ b/benchmark/data/all_benchmark_data.csv @@ -0,0 +1,1957 @@ +kernel_name,kernel_provider,kernel_operation_mode,metric_name,metric_unit,x_name,x_label,x_value,y_value_50,y_value_20,y_value_80,extra_benchmark_config_str,gpu_name,timestamp,liger_version +cross_entropy,liger,forward,speed,ms,V,vocab size,4096,0.5324159860610962,0.5291008353233337,0.53476482629776,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:39,0.2.1 +cross_entropy,liger,forward,speed,ms,V,vocab size,8192,0.8101439476013184,0.7565760016441345,0.9144319891929626,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 
15:31:39,0.2.1 +cross_entropy,liger,forward,speed,ms,V,vocab size,16384,1.4320800304412842,1.4087040424346924,1.5254720449447632,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:39,0.2.1 +cross_entropy,liger,forward,speed,ms,V,vocab size,32768,2.8378241062164307,2.805759906768799,2.9447360038757324,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:39,0.2.1 +cross_entropy,liger,forward,speed,ms,V,vocab size,65536,6.805135726928711,6.790579319000244,6.98748779296875,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:39,0.2.1 +cross_entropy,liger,forward,speed,ms,V,vocab size,131072,15.009359359741211,15.00483226776123,15.045599937438965,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:39,0.2.1 +cross_entropy,huggingface,forward,speed,ms,V,vocab size,4096,0.8751360177993774,0.87330561876297,0.8773248195648193,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:40,0.2.1 +cross_entropy,huggingface,forward,speed,ms,V,vocab size,8192,1.188480019569397,1.1871488094329834,1.1901824474334717,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:40,0.2.1 +cross_entropy,huggingface,forward,speed,ms,V,vocab size,16384,1.9522240161895752,1.9451839923858643,1.962073564529419,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:40,0.2.1 +cross_entropy,huggingface,forward,speed,ms,V,vocab size,32768,5.316768169403076,5.314131259918213,5.319046497344971,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:40,0.2.1 +cross_entropy,huggingface,forward,speed,ms,V,vocab size,65536,10.615103721618652,10.607129096984863,10.61723518371582,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:40,0.2.1 +cross_entropy,huggingface,forward,speed,ms,V,vocab size,131072,20.72643280029297,20.72038459777832,20.758554458618164,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:40,0.2.1 +cross_entropy,liger,full,speed,ms,V,vocab size,4096,0.8637440204620361,0.8607680201530457,0.8670976161956787,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:41,0.2.1 +cross_entropy,liger,full,speed,ms,V,vocab size,8192,1.462272047996521,1.4576319456100464,1.4661248922348022,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:41,0.2.1 +cross_entropy,liger,full,speed,ms,V,vocab size,16384,2.7454559803009033,2.741612672805786,2.780428647994995,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:41,0.2.1 +cross_entropy,liger,full,speed,ms,V,vocab size,32768,5.403264045715332,5.398873329162598,5.4122114181518555,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:41,0.2.1 +cross_entropy,liger,full,speed,ms,V,vocab size,65536,11.925024032592773,11.919878005981445,11.92919635772705,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:41,0.2.1 +cross_entropy,liger,full,speed,ms,V,vocab size,131072,25.22287940979004,25.21867561340332,25.23493766784668,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:41,0.2.1 +cross_entropy,huggingface,full,speed,ms,V,vocab size,4096,2.2260000705718994,2.2239038944244385,2.2290303707122803,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,huggingface,full,speed,ms,V,vocab size,8192,3.5976319313049316,3.595616102218628,3.6007039546966553,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,huggingface,full,speed,ms,V,vocab size,16384,6.8023200035095215,6.795276641845703,6.806528091430664,"{""B"": 8, ""T"": 2048}",NVIDIA 
A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,huggingface,full,speed,ms,V,vocab size,32768,15.486032485961914,15.483936309814453,15.48681640625,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,huggingface,full,speed,ms,V,vocab size,65536,30.778079986572266,30.76335334777832,30.77827262878418,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,huggingface,full,speed,ms,V,vocab size,131072,60.43830490112305,60.43830490112305,60.43830490112305,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,liger,full,memory,MB,V,vocab size,4096,256.32861328125,256.32861328125,256.32861328125,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,liger,full,memory,MB,V,vocab size,8192,512.32861328125,512.32861328125,512.32861328125,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,liger,full,memory,MB,V,vocab size,16384,1024.32861328125,1024.32861328125,1024.32861328125,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,liger,full,memory,MB,V,vocab size,32768,2048.32861328125,2048.32861328125,2048.32861328125,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,liger,full,memory,MB,V,vocab size,65536,4096.32861328125,4096.32861328125,4096.32861328125,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,liger,full,memory,MB,V,vocab size,131072,8192.328125,8192.328125,8192.328125,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,huggingface,full,memory,MB,V,vocab size,4096,1280.1259765625,1280.1259765625,1280.1259765625,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,huggingface,full,memory,MB,V,vocab size,8192,2560.1259765625,2560.1259765625,2560.1259765625,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,huggingface,full,memory,MB,V,vocab size,16384,5120.1259765625,5120.1259765625,5120.1259765625,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,huggingface,full,memory,MB,V,vocab size,32768,10240.1259765625,10240.1259765625,10240.1259765625,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,huggingface,full,memory,MB,V,vocab size,65536,20480.125,20480.125,20480.125,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +cross_entropy,huggingface,full,memory,MB,V,vocab size,131072,40960.125,40960.125,40960.125,"{""B"": 8, ""T"": 2048}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:42,0.2.1 +embedding,liger,forward,speed,ms,V,embedding dimension,1024,0.04262400045990944,0.04214400053024292,0.04428799822926521,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:53,0.2.1 +embedding,liger,forward,speed,ms,V,embedding dimension,2048,0.04668800160288811,0.04560000076889992,0.04825599864125252,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:53,0.2.1 +embedding,liger,forward,speed,ms,V,embedding dimension,4096,0.0493599995970726,0.048153601586818695,0.05084799975156784,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:53,0.2.1 +embedding,liger,forward,speed,ms,V,embedding 
dimension,8192,0.05558399856090546,0.054207999259233475,0.0568000003695488,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:53,0.2.1 +embedding,liger,forward,speed,ms,V,embedding dimension,16384,0.061503998935222626,0.06022400036454201,0.06260479986667633,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:53,0.2.1 +embedding,liger,forward,speed,ms,V,embedding dimension,32768,0.06518399715423584,0.06406400352716446,0.06634879857301712,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:53,0.2.1 +embedding,liger,forward,speed,ms,V,embedding dimension,65536,0.06779199838638306,0.06656000018119812,0.06905599683523178,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:53,0.2.1 +embedding,liger,forward,speed,ms,V,embedding dimension,131072,0.07091200351715088,0.06963200122117996,0.07225599884986877,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:53,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding dimension,1024,0.16672000288963318,0.1416832059621811,0.16777600347995758,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:56,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding dimension,2048,0.14406399428844452,0.1435839980840683,0.1446399986743927,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:56,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding dimension,4096,0.1539199948310852,0.15334400534629822,0.1546431928873062,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:56,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding dimension,8192,0.1627199947834015,0.16179199516773224,0.16357119381427765,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:56,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding dimension,16384,0.1666879951953888,0.16587519645690918,0.16772480309009552,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:56,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding dimension,32768,0.1687680035829544,0.16784639656543732,0.1697216033935547,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:56,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding dimension,65536,0.16918399930000305,0.1685439944267273,0.17001600563526154,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:56,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding dimension,131072,0.17027199268341064,0.16927999258041382,0.17123199999332428,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:31:56,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,1024,0.039712000638246536,0.03798399865627289,0.04079360142350197,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:01,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,2048,0.04652800038456917,0.045318398624658585,0.04755200073122978,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA 
A100-SXM4-80GB,2024-09-03 15:32:01,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,4096,0.05462399870157242,0.05361919850111008,0.05580800026655197,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:01,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,8192,0.06015999987721443,0.059487998485565186,0.06102399900555611,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:01,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,16384,0.06412799656391144,0.06329599767923355,0.06508159637451172,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:01,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,32768,0.066880002617836,0.06583040207624435,0.06777600198984146,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:01,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,65536,0.06896000355482101,0.06785280257463455,0.07009919732809067,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:01,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,131072,0.06915199756622314,0.0682239979505539,0.06998399645090103,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:01,0.2.1 +embedding,liger,full,speed,ms,V,embedding dimension,1024,0.44515201449394226,0.4440639913082123,0.4463231861591339,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:05,0.2.1 +embedding,liger,full,speed,ms,V,embedding dimension,2048,0.4620960056781769,0.4610239863395691,0.46300798654556274,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:05,0.2.1 +embedding,liger,full,speed,ms,V,embedding dimension,4096,0.49136000871658325,0.4905087947845459,0.49270400404930115,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:05,0.2.1 +embedding,liger,full,speed,ms,V,embedding dimension,8192,0.5527999997138977,0.5520448088645935,0.5538623929023743,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:05,0.2.1 +embedding,liger,full,speed,ms,V,embedding dimension,16384,0.6350079774856567,0.6340479850769043,0.6363840103149414,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:05,0.2.1 +embedding,liger,full,speed,ms,V,embedding dimension,32768,0.7710559964179993,0.7691839933395386,0.7727680206298828,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:05,0.2.1 +embedding,liger,full,speed,ms,V,embedding dimension,65536,1.002560019493103,1.0006400346755981,1.004467248916626,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:05,0.2.1 +embedding,liger,full,speed,ms,V,embedding dimension,131072,1.4482879638671875,1.4459072351455688,1.4513407945632935,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:05,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,1024,0.4537919759750366,0.4517247974872589,0.46081918478012085,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": 
""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:08,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,2048,0.47407999634742737,0.4729023873806,0.47523200511932373,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:08,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,4096,0.5310080051422119,0.5298879742622375,0.5320383906364441,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:08,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,8192,0.6528639793395996,0.6514303684234619,0.6546239852905273,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:08,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,16384,0.8056960105895996,0.8048319816589355,0.807424008846283,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:08,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,32768,0.954543948173523,0.9533119797706604,0.9559999704360962,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:08,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,65536,1.1960480213165283,1.1946111917495728,1.1982656717300415,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:08,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,131072,1.642624020576477,1.6409599781036377,1.6447807550430298,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:08,0.2.1 +embedding,torch_compile,full,speed,ms,V,embedding dimension,1024,0.3001280128955841,0.29503998160362244,0.30576640367507935,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:13,0.2.1 +embedding,torch_compile,full,speed,ms,V,embedding dimension,2048,0.297760009765625,0.2938239872455597,0.3054080009460449,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:13,0.2.1 +embedding,torch_compile,full,speed,ms,V,embedding dimension,4096,0.2991679906845093,0.2956480085849762,0.3070079982280731,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:13,0.2.1 +embedding,torch_compile,full,speed,ms,V,embedding dimension,8192,0.2961280047893524,0.2899264097213745,0.3029248118400574,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:13,0.2.1 +embedding,torch_compile,full,speed,ms,V,embedding dimension,16384,0.3465920090675354,0.34563198685646057,0.3476351797580719,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:13,0.2.1 +embedding,torch_compile,full,speed,ms,V,embedding dimension,32768,0.46585598587989807,0.4641471803188324,0.4674175977706909,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:13,0.2.1 +embedding,torch_compile,full,speed,ms,V,embedding dimension,65536,0.6924160122871399,0.6907200217247009,0.6938239932060242,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:13,0.2.1 +embedding,torch_compile,full,speed,ms,V,embedding dimension,131072,1.1352640390396118,1.1327999830245972,1.1376447677612305,"{""B"": 32, ""T"": 
512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:13,0.2.1 +embedding,liger,forward,speed,ms,V,embedding dimension,1024,0.18961599469184875,0.1879040002822876,0.19174399971961975,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:28,0.2.1 +embedding,liger,forward,speed,ms,V,embedding dimension,2048,0.21296000480651855,0.2112639993429184,0.21513600647449493,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:28,0.2.1 +embedding,liger,forward,speed,ms,V,embedding dimension,4096,0.2367040067911148,0.23467519879341125,0.23888640105724335,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:28,0.2.1 +embedding,liger,forward,speed,ms,V,embedding dimension,8192,0.26335999369621277,0.26099199056625366,0.2656640112400055,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:28,0.2.1 +embedding,liger,forward,speed,ms,V,embedding dimension,16384,0.2850880026817322,0.28336000442504883,0.2869440019130707,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:28,0.2.1 +embedding,liger,forward,speed,ms,V,embedding dimension,32768,0.30460798740386963,0.3023360073566437,0.30684158205986023,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:28,0.2.1 +embedding,liger,forward,speed,ms,V,embedding dimension,65536,0.31569600105285645,0.3138048052787781,0.3180544078350067,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:28,0.2.1 +embedding,liger,forward,speed,ms,V,embedding dimension,131072,0.31988799571990967,0.31808000802993774,0.3219392001628876,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:28,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding dimension,1024,0.7865599989891052,0.7846271991729736,0.7891008257865906,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:43,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding dimension,2048,0.8262079954147339,0.8236607909202576,0.8279871940612793,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:43,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding dimension,4096,0.8446240425109863,0.8429504036903381,0.8475391864776611,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:43,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding dimension,8192,0.8540480136871338,0.8518400192260742,0.8557760119438171,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:43,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding dimension,16384,0.857695996761322,0.8553280234336853,0.8595200181007385,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:43,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding dimension,32768,0.8596479892730713,0.8576639890670776,0.8618879914283752,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:43,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding 
dimension,65536,1.0087039470672607,0.8624832034111023,1.0126848220825195,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:43,0.2.1 +embedding,huggingface,forward,speed,ms,V,embedding dimension,131072,0.8633919954299927,0.8609600067138672,0.8647680282592773,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:43,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,1024,0.2572160065174103,0.255840003490448,0.25833600759506226,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:58,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,2048,0.2817760109901428,0.2805440127849579,0.2831552028656006,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:58,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,4096,0.30182400345802307,0.3002175986766815,0.3032831847667694,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:58,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,8192,0.3126400113105774,0.3114303946495056,0.31427839398384094,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:58,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,16384,0.3190400004386902,0.31795841455459595,0.32016000151634216,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:58,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,32768,0.32419198751449585,0.32281601428985596,0.32559359073638916,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:58,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,65536,0.3238080143928528,0.32236799597740173,0.3250240087509155,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:58,0.2.1 +embedding,torch_compile,forward,speed,ms,V,embedding dimension,131072,0.3256959915161133,0.32434558868408203,0.32689279317855835,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:32:58,0.2.1 +embedding,liger,full,speed,ms,V,embedding dimension,1024,2.17740797996521,2.1755776405334473,2.180025577545166,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:13,0.2.1 +embedding,liger,full,speed,ms,V,embedding dimension,2048,2.2861440181732178,2.284735918045044,2.2882239818573,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:13,0.2.1 +embedding,liger,full,speed,ms,V,embedding dimension,4096,2.4825921058654785,2.48024320602417,2.484800100326538,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:13,0.2.1 +embedding,liger,full,speed,ms,V,embedding dimension,8192,2.74452805519104,2.7430784702301025,2.7452287673950195,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:13,0.2.1 +embedding,liger,full,speed,ms,V,embedding dimension,16384,3.1216320991516113,3.1202433109283447,3.125638484954834,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:13,0.2.1 
+embedding,liger,full,speed,ms,V,embedding dimension,32768,3.7801599502563477,3.774118423461914,3.7824511528015137,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:13,0.2.1 +embedding,liger,full,speed,ms,V,embedding dimension,65536,4.991136074066162,4.9875006675720215,4.993491172790527,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:13,0.2.1 +embedding,liger,full,speed,ms,V,embedding dimension,131072,7.383471965789795,7.377497673034668,7.386828899383545,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:13,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,1024,1.5774879455566406,1.5668543577194214,1.7933248281478882,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:28,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,2048,1.7074079513549805,1.7012799978256226,1.8109056949615479,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:28,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,4096,1.950543999671936,1.9466559886932373,1.9592640399932861,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:28,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,8192,2.404927968978882,2.400460720062256,2.4551360607147217,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:28,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,16384,3.119904041290283,3.1171774864196777,3.1267263889312744,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:28,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,32768,4.32857608795166,4.321491241455078,4.439519882202148,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:28,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,65536,5.065216064453125,5.059558391571045,5.115980625152588,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:28,0.2.1 +embedding,huggingface,full,speed,ms,V,embedding dimension,131072,7.489376068115234,7.484294414520264,7.5203776359558105,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:28,0.2.1 +embedding,torch_compile,full,speed,ms,V,embedding dimension,1024,1.0930559635162354,1.0918079614639282,1.0945919752120972,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:43,0.2.1 +embedding,torch_compile,full,speed,ms,V,embedding dimension,2048,1.1930559873580933,1.191705584526062,1.1951104402542114,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:43,0.2.1 +embedding,torch_compile,full,speed,ms,V,embedding dimension,4096,1.3096319437026978,1.3073855638504028,1.3119615316390991,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:43,0.2.1 +embedding,torch_compile,full,speed,ms,V,embedding dimension,8192,1.4822720289230347,1.480512022972107,1.4839999675750732,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:43,0.2.1 
+embedding,torch_compile,full,speed,ms,V,embedding dimension,16384,1.7870559692382812,1.7859647274017334,1.7892736196517944,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:43,0.2.1
+embedding,torch_compile,full,speed,ms,V,embedding dimension,32768,2.3838400840759277,2.381312131881714,2.3860929012298584,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:43,0.2.1
+embedding,torch_compile,full,speed,ms,V,embedding dimension,65536,3.7430078983306885,3.740166425704956,3.745452880859375,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:43,0.2.1
+embedding,torch_compile,full,speed,ms,V,embedding dimension,131072,5.940896034240723,5.934713363647461,5.943462371826172,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:43,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,1024,12348.125,12348.125,12348.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:45,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,2048,12360.125,12360.125,12360.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:45,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,4096,12384.125,12384.125,12384.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:45,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,8192,12432.125,12432.125,12432.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:45,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,16384,12528.125,12528.125,12528.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:45,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,32768,12720.125,12720.125,12720.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:45,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,65536,13104.125,13104.125,13104.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:45,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,131072,13872.125,13872.125,13872.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:45,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,1024,12356.537109375,12356.537109375,12356.537109375,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:48,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,2048,12371.359375,12371.359375,12371.359375,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:48,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,4096,12401.40625,12401.40625,12401.40625,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:48,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,8192,12461.5,12461.5,12461.5,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:48,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,16384,12581.6875,12581.6875,12581.6875,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:48,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,32768,12773.6875,12773.6875,12773.6875,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:48,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,65536,13157.6875,13157.6875,13157.6875,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:48,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,131072,13925.6875,13925.6875,13925.6875,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:48,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,1024,12348.125,12348.125,12348.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:52,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,2048,12366.125,12366.125,12366.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:52,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,4096,12402.125,12402.125,12402.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:52,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,8192,12474.125,12474.125,12474.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:52,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,16384,12618.125,12618.125,12618.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:52,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,32768,12906.125,12906.125,12906.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:52,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,65536,13482.125,13482.125,13482.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:52,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,131072,14634.125,14634.125,14634.125,"{""B"": 32, ""T"": 512, ""D"": 768, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:33:52,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,1024,14346.125,14346.125,14346.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:04,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,2048,14410.125,14410.125,14410.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:04,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,4096,14538.125,14538.125,14538.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:04,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,8192,14794.125,14794.125,14794.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:04,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,16384,15306.125,15306.125,15306.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:04,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,32768,16330.125,16330.125,16330.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:04,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,65536,18378.125,18378.125,18378.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:04,0.2.1
+embedding,liger,full,memory,MB,V,embedding dimension,131072,22474.125,22474.125,22474.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:04,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,1024,14388.130859375,14388.130859375,14388.130859375,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:17,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,2048,14468.154296875,14468.154296875,14468.154296875,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:17,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,4096,14628.201171875,14628.201171875,14628.201171875,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:17,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,8192,14948.294921875,14948.294921875,14948.294921875,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:17,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,16384,15588.482421875,15588.482421875,15588.482421875,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:17,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,32768,16612.482421875,16612.482421875,16612.482421875,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:17,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,65536,18660.482421875,18660.482421875,18660.482421875,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:17,0.2.1
+embedding,huggingface,full,memory,MB,V,embedding dimension,131072,22756.482421875,22756.482421875,22756.482421875,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:17,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,1024,14346.125,14346.125,14346.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:31,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,2048,14442.125,14442.125,14442.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:31,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,4096,14634.125,14634.125,14634.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:31,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,8192,15018.125,15018.125,15018.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:31,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,16384,1536.125,1536.125,1536.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:31,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,32768,3072.125,3072.125,3072.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:31,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,65536,6144.125,6144.125,6144.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:31,0.2.1
+embedding,torch_compile,full,memory,MB,V,embedding dimension,131072,12288.125,12288.125,12288.125,"{""B"": 8, ""T"": 2048, ""D"": 4096, ""dtype"": ""torch.float32""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:34:31,0.2.1
+fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,4096,119.52153778076172,119.52153778076172,119.52153778076172,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:03,0.4.2
+fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,8192,168.08563232421875,168.08563232421875,168.08563232421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:03,0.4.2
+fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,16384,274.07342529296875,274.07342529296875,274.07342529296875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:03,0.4.2
+fused_linear_cross_entropy,liger,forward,speed,ms,BT,B x T,32768,508.4652099609375,508.4652099609375,508.4652099609375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:03,0.4.2
+fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,4096,20.911680221557617,20.90903663635254,20.915321350097656,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:34,0.4.2
+fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,8192,37.97203063964844,37.9546012878418,37.989463806152344,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:34,0.4.2
+fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,16384,76.39142608642578,76.39142608642578,76.39142608642578,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:34,0.4.2
+fused_linear_cross_entropy,huggingface,forward,speed,ms,BT,B x T,32768,151.91404724121094,151.91404724121094,151.91404724121094,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:44:34,0.4.2
+fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,4096,121.43059539794922,121.43059539794922,121.43059539794922,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:11,0.4.2
+fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,8192,166.70867919921875,166.70867919921875,166.70867919921875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:11,0.4.2
+fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,16384,277.1166687011719,277.1166687011719,277.1166687011719,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:11,0.4.2
+fused_linear_cross_entropy,liger,full,speed,ms,BT,B x T,32768,511.0638732910156,511.0638732910156,511.0638732910156,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:11,0.4.2
+fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,4096,55.96684646606445,55.96684646606445,55.96684646606445,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:46,0.4.2
+fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,8192,111.45471954345703,111.45471954345703,111.45471954345703,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:46,0.4.2
+fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,16384,220.7836151123047,220.7836151123047,220.7836151123047,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:46,0.4.2
+fused_linear_cross_entropy,huggingface,full,speed,ms,BT,B x T,32768,452.4712829589844,452.4712829589844,452.4712829589844,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:45:46,0.4.2
+fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,4096,4245.5478515625,4245.5478515625,4245.5478515625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:25,0.4.2
+fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,8192,4466.9697265625,4466.9697265625,4466.9697265625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:25,0.4.2
+fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,16384,4910.4384765625,4910.4384765625,4910.4384765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:25,0.4.2
+fused_linear_cross_entropy,liger,full,memory,MB,BT,B x T,32768,5794.6259765625,5794.6259765625,5794.6259765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:25,0.4.2
+fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,4096,6092.2822265625,6092.2822265625,6092.2822265625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:53,0.4.2
+fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,8192,9162.3134765625,9162.3134765625,9162.3134765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:53,0.4.2
+fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,16384,15302.3759765625,15302.3759765625,15302.3759765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:53,0.4.2
+fused_linear_cross_entropy,huggingface,full,memory,MB,BT,B x T,32768,27582.5,27582.5,27582.5,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-22 17:46:53,0.4.2
+geglu,liger,full,speed,ms,T,sequence length,1024,30.03536033630371,30.03536033630371,30.03536033630371,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:14,0.2.1
+geglu,liger,full,speed,ms,T,sequence length,2048,54.04060745239258,54.04060745239258,54.04060745239258,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:14,0.2.1
+geglu,liger,full,speed,ms,T,sequence length,4096,108.52435302734375,108.52435302734375,108.52435302734375,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:14,0.2.1
+geglu,liger,full,speed,ms,T,sequence length,8192,216.6227264404297,216.6227264404297,216.6227264404297,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:14,0.2.1
+geglu,huggingface,full,speed,ms,T,sequence length,1024,27.938560485839844,27.938560485839844,27.938560485839844,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:21,0.2.1
+geglu,huggingface,full,speed,ms,T,sequence length,2048,54.51279830932617,54.51279830932617,54.51279830932617,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:21,0.2.1
+geglu,huggingface,full,speed,ms,T,sequence length,4096,110.97718048095703,110.97718048095703,110.97718048095703,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:21,0.2.1
+geglu,huggingface,full,speed,ms,T,sequence length,8192,220.93954467773438,220.93954467773438,220.93954467773438,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:21,0.2.1
+geglu,liger,forward,speed,ms,T,sequence length,1024,9.280096054077148,9.280096054077148,9.280096054077148,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:26,0.2.1
+geglu,liger,forward,speed,ms,T,sequence length,2048,17.59040069580078,17.59040069580078,17.59040069580078,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:26,0.2.1
+geglu,liger,forward,speed,ms,T,sequence length,4096,36.18726348876953,36.18726348876953,36.18726348876953,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:26,0.2.1
+geglu,liger,forward,speed,ms,T,sequence length,8192,72.60655975341797,72.60655975341797,72.60655975341797,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:26,0.2.1
+geglu,huggingface,forward,speed,ms,T,sequence length,1024,9.257439613342285,9.257439613342285,9.257439613342285,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:31,0.2.1
+geglu,huggingface,forward,speed,ms,T,sequence length,2048,18.099519729614258,18.099519729614258,18.099519729614258,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:31,0.2.1
+geglu,huggingface,forward,speed,ms,T,sequence length,4096,36.37263870239258,36.37263870239258,36.37263870239258,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:31,0.2.1
+geglu,huggingface,forward,speed,ms,T,sequence length,8192,72.66553497314453,72.66553497314453,72.66553497314453,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:31,0.2.1
+geglu,liger,backward,speed,ms,T,sequence length,1024,18.088287353515625,18.088287353515625,18.088287353515625,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:37,0.2.1
+geglu,liger,backward,speed,ms,T,sequence length,2048,35.195518493652344,35.195518493652344,35.195518493652344,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:37,0.2.1
+geglu,liger,backward,speed,ms,T,sequence length,4096,70.51395416259766,70.51395416259766,70.51395416259766,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:37,0.2.1
+geglu,liger,backward,speed,ms,T,sequence length,8192,141.28550720214844,141.28550720214844,141.28550720214844,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:37,0.2.1
+geglu,huggingface,backward,speed,ms,T,sequence length,1024,18.521728515625,18.521728515625,18.521728515625,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:42,0.2.1
+geglu,huggingface,backward,speed,ms,T,sequence length,2048,36.045406341552734,36.045406341552734,36.045406341552734,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:42,0.2.1
+geglu,huggingface,backward,speed,ms,T,sequence length,4096,72.88412475585938,72.88412475585938,72.88412475585938,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:42,0.2.1
+geglu,huggingface,backward,speed,ms,T,sequence length,8192,144.2132110595703,144.2132110595703,144.2132110595703,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:42,0.2.1
+geglu,liger,full,memory,MB,T,sequence length,1024,1582.25,1582.25,1582.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:45,0.2.1
+geglu,liger,full,memory,MB,T,sequence length,2048,2546.25,2546.25,2546.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:45,0.2.1
+geglu,liger,full,memory,MB,T,sequence length,4096,4474.25,4474.25,4474.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:45,0.2.1
+geglu,liger,full,memory,MB,T,sequence length,8192,8330.25,8330.25,8330.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:45,0.2.1
+geglu,huggingface,full,memory,MB,T,sequence length,1024,1992.25,1992.25,1992.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:50,0.2.1
+geglu,huggingface,full,memory,MB,T,sequence length,2048,3452.25,3452.25,3452.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:50,0.2.1
+geglu,huggingface,full,memory,MB,T,sequence length,4096,6372.25,6372.25,6372.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:50,0.2.1
+geglu,huggingface,full,memory,MB,T,sequence length,8192,12212.25,12212.25,12212.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:50,0.2.1
+geglu,liger,forward,memory,MB,T,sequence length,1024,918.25,918.25,918.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:55,0.2.1
+geglu,liger,forward,memory,MB,T,sequence length,2048,1562.25,1562.25,1562.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:55,0.2.1
+geglu,liger,forward,memory,MB,T,sequence length,4096,2850.25,2850.25,2850.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:55,0.2.1
+geglu,liger,forward,memory,MB,T,sequence length,8192,5426.25,5426.25,5426.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:55,0.2.1
+geglu,huggingface,forward,memory,MB,T,sequence length,1024,1090.25,1090.25,1090.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:58,0.2.1
+geglu,huggingface,forward,memory,MB,T,sequence length,2048,1906.25,1906.25,1906.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:58,0.2.1
+geglu,huggingface,forward,memory,MB,T,sequence length,4096,3538.25,3538.25,3538.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:58,0.2.1
+geglu,huggingface,forward,memory,MB,T,sequence length,8192,6802.25,6802.25,6802.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:38:58,0.2.1
+geglu,liger,backward,memory,MB,T,sequence length,1024,1582.25,1582.25,1582.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:02,0.2.1
+geglu,liger,backward,memory,MB,T,sequence length,2048,2546.25,2546.25,2546.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:02,0.2.1
+geglu,liger,backward,memory,MB,T,sequence length,4096,4474.25,4474.25,4474.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:02,0.2.1
+geglu,liger,backward,memory,MB,T,sequence length,8192,8330.25,8330.25,8330.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:02,0.2.1
+geglu,huggingface,backward,memory,MB,T,sequence length,1024,1992.25,1992.25,1992.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:06,0.2.1
+geglu,huggingface,backward,memory,MB,T,sequence length,2048,3452.25,3452.25,3452.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:06,0.2.1
+geglu,huggingface,backward,memory,MB,T,sequence length,4096,6372.25,6372.25,6372.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:06,0.2.1
+geglu,huggingface,backward,memory,MB,T,sequence length,8192,12212.25,12212.25,12212.25,"{""bsz"": 8, ""hidden_size"": 4096, ""intermediate_size"": 11008, ""hidden_act"": ""gelu_pytorch_tanh"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:06,0.2.1
+layer_norm,liger,forward,speed,ms,N,hidden size,1024,0.030271999537944794,0.02921600081026554,0.03142400085926056,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:14,0.2.1
+layer_norm,liger,forward,speed,ms,N,hidden size,2048,0.04992000013589859,0.04912000149488449,0.050783999264240265,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:14,0.2.1
+layer_norm,liger,forward,speed,ms,N,hidden size,4096,0.08816000074148178,0.08739200234413147,0.08899199962615967,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:14,0.2.1
+layer_norm,liger,forward,speed,ms,N,hidden size,8192,0.16521599888801575,0.16435199975967407,0.16627199947834015,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:14,0.2.1
+layer_norm,liger,forward,speed,ms,N,hidden size,16384,0.32230401039123535,0.32070401310920715,0.32393598556518555,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:14,0.2.1
+layer_norm,huggingface,forward,speed,ms,N,hidden size,1024,0.034143999218940735,0.033376000821590424,0.03580800071358681,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:17,0.2.1
+layer_norm,huggingface,forward,speed,ms,N,hidden size,2048,0.05734400078654289,0.05615999922156334,0.05859199911355972,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:17,0.2.1
+layer_norm,huggingface,forward,speed,ms,N,hidden size,4096,0.1218239963054657,0.12054400146007538,0.12316799908876419,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:17,0.2.1
+layer_norm,huggingface,forward,speed,ms,N,hidden size,8192,0.25755199790000916,0.255840003490448,0.25939199328422546,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:17,0.2.1
+layer_norm,huggingface,forward,speed,ms,N,hidden size,16384,0.5066879987716675,0.5045183897018433,0.5089280009269714,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:17,0.2.1
+layer_norm,liger,full,speed,ms,N,hidden size,1024,0.28019198775291443,0.2780799865722656,0.284960001707077,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:19,0.2.1
+layer_norm,liger,full,speed,ms,N,hidden size,2048,0.27827200293540955,0.27638399600982666,0.2824704051017761,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:19,0.2.1
+layer_norm,liger,full,speed,ms,N,hidden size,4096,0.2847039997577667,0.27955201268196106,0.2908479869365692,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:19,0.2.1
+layer_norm,liger,full,speed,ms,N,hidden size,8192,0.4405759871006012,0.43780481815338135,0.4440320134162903,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:19,0.2.1
+layer_norm,liger,full,speed,ms,N,hidden size,16384,1.1488319635391235,1.1439871788024902,1.1527807712554932,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:19,0.2.1
+layer_norm,huggingface,full,speed,ms,N,hidden size,1024,0.11884800344705582,0.11750400066375732,0.12035199999809265,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:21,0.2.1
+layer_norm,huggingface,full,speed,ms,N,hidden size,2048,0.1966399997472763,0.19432319700717926,0.19888000190258026,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:21,0.2.1
+layer_norm,huggingface,full,speed,ms,N,hidden size,4096,0.43142399191856384,0.42931199073791504,0.4336639940738678,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:21,0.2.1
+layer_norm,huggingface,full,speed,ms,N,hidden size,8192,0.829584002494812,0.826918363571167,0.832857608795166,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:21,0.2.1
+layer_norm,huggingface,full,speed,ms,N,hidden size,16384,1.6212799549102783,1.6171647310256958,1.6246912479400635,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:21,0.2.1
+layer_norm,liger,full,memory,MB,N,hidden size,1024,80.90625,80.90625,80.90625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:21,0.2.1
+layer_norm,liger,full,memory,MB,N,hidden size,2048,161.78125,161.78125,161.78125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:21,0.2.1
+layer_norm,liger,full,memory,MB,N,hidden size,4096,323.53125,323.53125,323.53125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:21,0.2.1
+layer_norm,liger,full,memory,MB,N,hidden size,8192,647.03125,647.03125,647.03125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:21,0.2.1
+layer_norm,liger,full,memory,MB,N,hidden size,16384,1294.03125,1294.03125,1294.03125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:21,0.2.1
+layer_norm,huggingface,full,memory,MB,N,hidden size,1024,80.0625,80.0625,80.0625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:21,0.2.1
+layer_norm,huggingface,full,memory,MB,N,hidden size,2048,160.09375,160.09375,160.09375,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:21,0.2.1
+layer_norm,huggingface,full,memory,MB,N,hidden size,4096,320.15625,320.15625,320.15625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:21,0.2.1
+layer_norm,huggingface,full,memory,MB,N,hidden size,8192,640.28125,640.28125,640.28125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:21,0.2.1
+layer_norm,huggingface,full,memory,MB,N,hidden size,16384,1280.53125,1280.53125,1280.53125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:21,0.2.1
+rms_norm,liger,forward,speed,ms,H,hidden size,1024,0.01360000018030405,0.012864000163972378,0.01603199914097786,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:30,0.2.1
+rms_norm,liger,forward,speed,ms,H,hidden size,2048,0.019999999552965164,0.018624000251293182,0.02160000056028366,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:30,0.2.1
+rms_norm,liger,forward,speed,ms,H,hidden size,4096,0.031072000041604042,0.030047999694943428,0.031968001276254654,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:30,0.2.1
+rms_norm,liger,forward,speed,ms,H,hidden size,8192,0.0517439991235733,0.050624001771211624,0.05289600044488907,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:30,0.2.1
+rms_norm,liger,forward,speed,ms,H,hidden size,16384,0.0952640026807785,0.0942080020904541,0.09667199850082397,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:30,0.2.1
+rms_norm,liger,forward,speed,ms,H,hidden size,32768,0.18223999440670013,0.18035200238227844,0.18417279422283173,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:30,0.2.1
+rms_norm,huggingface,forward,speed,ms,H,hidden size,1024,0.07820799946784973,0.0777600035071373,0.0790719985961914,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:33,0.2.1
+rms_norm,huggingface,forward,speed,ms,H,hidden size,2048,0.13631999492645264,0.13555200397968292,0.13731199502944946,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:33,0.2.1
+rms_norm,huggingface,forward,speed,ms,H,hidden size,4096,0.27990400791168213,0.2789439857006073,0.28118398785591125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:33,0.2.1
+rms_norm,huggingface,forward,speed,ms,H,hidden size,8192,0.5190399885177612,0.5175359845161438,0.5209856033325195,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:33,0.2.1
+rms_norm,huggingface,forward,speed,ms,H,hidden size,16384,0.9856320023536682,0.9835839867591858,0.9876928329467773,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:33,0.2.1
+rms_norm,huggingface,forward,speed,ms,H,hidden size,32768,1.9190720319747925,1.917081594467163,1.921875238418579,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:33,0.2.1
+rms_norm,liger,full,speed,ms,H,hidden size,1024,0.28601598739624023,0.2837119996547699,0.29068800806999207,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:36,0.2.1
+rms_norm,liger,full,speed,ms,H,hidden size,2048,0.286624014377594,0.2845824062824249,0.2905920147895813,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:36,0.2.1
+rms_norm,liger,full,speed,ms,H,hidden size,4096,0.28830400109291077,0.28533118963241577,0.2935168147087097,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:36,0.2.1
+rms_norm,liger,full,speed,ms,H,hidden size,8192,0.29407998919487,0.289216011762619,0.3038719892501831,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:36,0.2.1
+rms_norm,liger,full,speed,ms,H,hidden size,16384,0.410863995552063,0.4088575839996338,0.41293439269065857,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:36,0.2.1
+rms_norm,liger,full,speed,ms,H,hidden size,32768,1.2316479682922363,1.228230357170105,1.235001564025879,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:36,0.2.1
+rms_norm,huggingface,full,speed,ms,H,hidden size,1024,0.3176960051059723,0.3147839903831482,0.32177281379699707,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:40,0.2.1
+rms_norm,huggingface,full,speed,ms,H,hidden size,2048,0.49038398265838623,0.4888896048069,0.4920639991760254,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:40,0.2.1
+rms_norm,huggingface,full,speed,ms,H,hidden size,4096,1.011423945426941,1.0089855194091797,1.013759970664978,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:40,0.2.1
+rms_norm,huggingface,full,speed,ms,H,hidden size,8192,1.8621759414672852,1.859769582748413,1.8646591901779175,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:40,0.2.1
+rms_norm,huggingface,full,speed,ms,H,hidden size,16384,3.5439999103546143,3.5410239696502686,3.547679901123047,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:40,0.2.1
+rms_norm,huggingface,full,speed,ms,H,hidden size,32768,6.910431861877441,6.907142639160156,6.914393901824951,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:40,0.2.1
+rms_norm,liger,backward,speed,ms,H,hidden size,1024,0.09372799843549728,0.09177599847316742,0.09763199836015701,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:43,0.2.1
+rms_norm,liger,backward,speed,ms,H,hidden size,2048,0.09030400216579437,0.08746880292892456,0.09398400038480759,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:43,0.2.1
+rms_norm,liger,backward,speed,ms,H,hidden size,4096,0.09913600236177444,0.09804800152778625,0.10039679706096649,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:43,0.2.1
+rms_norm,liger,backward,speed,ms,H,hidden size,8192,0.17801600694656372,0.1765120029449463,0.1793919950723648,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:43,0.2.1
+rms_norm,liger,backward,speed,ms,H,hidden size,16384,0.32051199674606323,0.3187839984893799,0.32230401039123535,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:43,0.2.1
+rms_norm,liger,backward,speed,ms,H,hidden size,32768,1.0562880039215088,1.053491234779358,1.059673547744751,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:43,0.2.1
+rms_norm,huggingface,backward,speed,ms,H,hidden size,1024,0.19577600061893463,0.19523200392723083,0.19631999731063843,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,huggingface,backward,speed,ms,H,hidden size,2048,0.36188799142837524,0.3601599931716919,0.363647997379303,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,huggingface,backward,speed,ms,H,hidden size,4096,0.7403839826583862,0.7381759881973267,0.7426176071166992,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,huggingface,backward,speed,ms,H,hidden size,8192,1.3515520095825195,1.348736047744751,1.3550655841827393,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,huggingface,backward,speed,ms,H,hidden size,16384,2.569632053375244,2.5663681030273438,2.5731201171875,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,huggingface,backward,speed,ms,H,hidden size,32768,5.0147199630737305,5.011123180389404,5.0179901123046875,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,liger,full,memory,MB,H,hidden size,1024,36.02392578125,36.02392578125,36.02392578125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,liger,full,memory,MB,H,hidden size,2048,72.03955078125,72.03955078125,72.03955078125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,liger,full,memory,MB,H,hidden size,4096,144.07080078125,144.07080078125,144.07080078125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,liger,full,memory,MB,H,hidden size,8192,268.13330078125,268.13330078125,268.13330078125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,liger,full,memory,MB,H,hidden size,16384,432.25830078125,432.25830078125,432.25830078125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,liger,full,memory,MB,H,hidden size,32768,752.5087890625,752.5087890625,752.5087890625,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,huggingface,full,memory,MB,H,hidden size,1024,80.01953125,80.01953125,80.01953125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,huggingface,full,memory,MB,H,hidden size,2048,160.03125,160.03125,160.03125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,huggingface,full,memory,MB,H,hidden size,4096,320.0546875,320.0546875,320.0546875,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,huggingface,full,memory,MB,H,hidden size,8192,640.1015625,640.1015625,640.1015625,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,huggingface,full,memory,MB,H,hidden size,16384,1280.1953125,1280.1953125,1280.1953125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rms_norm,huggingface,full,memory,MB,H,hidden size,32768,2560.3828125,2560.3828125,2560.3828125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:46,0.2.1
+rope,liger,forward,speed,ms,H,hidden size,512,0.011359999887645245,0.01033599954098463,0.011455999687314034,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:55,0.2.1
+rope,liger,forward,speed,ms,H,hidden size,2048,0.020864000543951988,0.020447999238967896,0.02239999920129776,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:55,0.2.1
+rope,liger,forward,speed,ms,H,hidden size,8192,0.059487998485565186,0.05830400064587593,0.06060799956321716,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:55,0.2.1
+rope,huggingface,forward,speed,ms,H,hidden size,512,0.07968000322580338,0.07923199981451035,0.10408961027860641,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:56,0.2.1
+rope,huggingface,forward,speed,ms,H,hidden size,2048,0.1570879966020584,0.15651200711727142,0.15785600244998932,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:56,0.2.1
+rope,huggingface,forward,speed,ms,H,hidden size,8192,0.5167999863624573,0.5161600112915039,0.5176640152931213,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:56,0.2.1
+rope,liger,backward,speed,ms,H,hidden size,512,0.12227199971675873,0.05539200082421303,0.1699904054403305,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:57,0.2.1
+rope,liger,backward,speed,ms,H,hidden size,2048,0.12337599694728851,0.11945600062608719,0.15338242053985596,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:57,0.2.1
+rope,liger,backward,speed,ms,H,hidden size,8192,0.12812800705432892,0.11593600362539291,0.1985855996608734,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:57,0.2.1
+rope,huggingface,backward,speed,ms,H,hidden size,512,0.2648000121116638,0.2489279955625534,0.3578239977359772,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:59,0.2.1
+rope,huggingface,backward,speed,ms,H,hidden size,2048,0.2536320090293884,0.24692480266094208,0.31929606199264526,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:59,0.2.1
+rope,huggingface,backward,speed,ms,H,hidden size,8192,0.621504008769989,0.6208000183105469,0.6223679780960083,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:39:59,0.2.1
+rope,liger,full,speed,ms,H,hidden size,512,0.27401599287986755,0.26447999477386475,0.3555007874965668,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:00,0.2.1
+rope,liger,full,speed,ms,H,hidden size,2048,0.2815040051937103,0.26904961466789246,0.3562496304512024,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:00,0.2.1
+rope,liger,full,speed,ms,H,hidden size,8192,0.2759679853916168,0.267244815826416,0.3601728081703186,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:00,0.2.1
+rope,huggingface,full,speed,ms,H,hidden size,512,0.5160639882087708,0.5028480291366577,0.6553279757499695,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:01,0.2.1
+rope,huggingface,full,speed,ms,H,hidden size,2048,0.5289119482040405,0.510598361492157,0.7208256721496582,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:01,0.2.1
+rope,huggingface,full,speed,ms,H,hidden size,8192,1.1329920291900635,1.1318720579147339,1.1339199542999268,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:01,0.2.1
+rope,liger,full,memory,MB,H,hidden size,512,13.26611328125,13.26611328125,13.26611328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:01,0.2.1
+rope,liger,full,memory,MB,H,hidden size,2048,28.64111328125,28.64111328125,28.64111328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:01,0.2.1
+rope,liger,full,memory,MB,H,hidden size,8192,90.14111328125,90.14111328125,90.14111328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:01,0.2.1
+rope,huggingface,full,memory,MB,H,hidden size,512,22.26611328125,22.26611328125,22.26611328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:01,0.2.1
+rope,huggingface,full,memory,MB,H,hidden size,2048,64.64111328125,64.64111328125,64.64111328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:01,0.2.1
+rope,huggingface,full,memory,MB,H,hidden size,8192,234.14111328125,234.14111328125,234.14111328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:01,0.2.1
+rope,liger,forward,speed,ms,T,sequence length,1024,0.034432001411914825,0.03340800106525421,0.03545600175857544,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:04,0.2.1
+rope,liger,forward,speed,ms,T,sequence length,2048,0.058880001306533813,0.0578560009598732,0.059859201312065125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:04,0.2.1
+rope,liger,forward,speed,ms,T,sequence length,4096,0.10899200290441513,0.10784000158309937,0.1101439967751503,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:04,0.2.1
+rope,liger,forward,speed,ms,T,sequence length,8192,0.20927999913692474,0.20796799659729004,0.21059200167655945,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:04,0.2.1
+rope,liger,forward,speed,ms,T,sequence length,16384,0.4105280041694641,0.4089151918888092,0.41203200817108154,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:04,0.2.1
+rope,huggingface,forward,speed,ms,T,sequence length,1024,0.2808319926261902,0.28019198775291443,0.28160640597343445,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:06,0.2.1
+rope,huggingface,forward,speed,ms,T,sequence length,2048,0.5160959959030151,0.5155072212219238,0.5169280171394348,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:06,0.2.1
+rope,huggingface,forward,speed,ms,T,sequence length,4096,0.9947839975357056,0.9939200282096863,0.9956799745559692,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:06,0.2.1
+rope,huggingface,forward,speed,ms,T,sequence length,8192,1.9332640171051025,1.9323519468307495,1.9344960451126099,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:06,0.2.1
+rope,huggingface,forward,speed,ms,T,sequence length,16384,3.8169920444488525,3.815808057785034,3.8180160522460938,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:06,0.2.1
+rope,liger,backward,speed,ms,T,sequence length,1024,0.1260479986667633,0.12014079838991165,0.143449604511261,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:08,0.2.1
+rope,liger,backward,speed,ms,T,sequence length,2048,0.11606399714946747,0.11021439731121063,0.12432000041007996,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:08,0.2.1
+rope,liger,backward,speed,ms,T,sequence length,4096,0.12409599870443344,0.11817599833011627,0.1313920021057129,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:08,0.2.1
+rope,liger,backward,speed,ms,T,sequence length,8192,0.21004800498485565,0.20867200195789337,0.21164800226688385,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:08,0.2.1
+rope,liger,backward,speed,ms,T,sequence length,16384,0.4102399945259094,0.40871042013168335,0.4119040071964264,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:08,0.2.1
+rope,huggingface,backward,speed,ms,T,sequence length,1024,0.3304319977760315,0.3296447992324829,0.3314239978790283,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:10,0.2.1
+rope,huggingface,backward,speed,ms,T,sequence length,2048,0.6213759779930115,0.6205440163612366,0.6223359704017639,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:10,0.2.1
+rope,huggingface,backward,speed,ms,T,sequence length,4096,1.1872799396514893,1.1858432292938232,1.1886080503463745,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:10,0.2.1
+rope,huggingface,backward,speed,ms,T,sequence length,8192,2.321280002593994,2.318873643875122,2.324160099029541,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:10,0.2.1
+rope,huggingface,backward,speed,ms,T,sequence length,16384,4.557248115539551,4.550220966339111,4.560742378234863,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:10,0.2.1
+rope,liger,full,speed,ms,T,sequence length,1024,0.2682560086250305,0.2641535997390747,0.2762559950351715,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:12,0.2.1
+rope,liger,full,speed,ms,T,sequence length,2048,0.2654559910297394,0.26105600595474243,0.2746559977531433,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:12,0.2.1
+rope,liger,full,speed,ms,T,sequence length,4096,0.2650560140609741,0.2608831822872162,0.2715519964694977,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:12,0.2.1
+rope,liger,full,speed,ms,T,sequence length,8192,0.4158720076084137,0.41413119435310364,0.4178048074245453,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:12,0.2.1
+rope,liger,full,speed,ms,T,sequence length,16384,0.8167039752006531,0.8143680095672607,0.8189184069633484,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:12,0.2.1
+rope,huggingface,full,speed,ms,T,sequence length,1024,0.6059200167655945,0.6047679781913757,0.6072319746017456,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:14,0.2.1
+rope,huggingface,full,speed,ms,T,sequence length,2048,1.1326719522476196,1.1318080425262451,1.133631944656372,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:14,0.2.1
+rope,huggingface,full,speed,ms,T,sequence length,4096,2.176192045211792,2.175136089324951,2.177433729171753,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:14,0.2.1
+rope,huggingface,full,speed,ms,T,sequence length,8192,4.248256206512451,4.246367931365967,4.2566399574279785,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:14,0.2.1
+rope,huggingface,full,speed,ms,T,sequence length,16384,8.365951538085938,8.36348819732666,8.380928039550781,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:14,0.2.1
+rope,liger,full,memory,MB,T,sequence length,1024,49.13330078125,49.13330078125,49.13330078125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:14,0.2.1
+rope,liger,full,memory,MB,T,sequence length,2048,90.14111328125,90.14111328125,90.14111328125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:14,0.2.1
+rope,liger,full,memory,MB,T,sequence length,4096,172.15673828125,172.15673828125,172.15673828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:14,0.2.1
+rope,liger,full,memory,MB,T,sequence length,8192,336.18798828125,336.18798828125,336.18798828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:14,0.2.1
+rope,liger,full,memory,MB,T,sequence length,16384,664.25048828125,664.25048828125,664.25048828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:14,0.2.1
+rope,huggingface,full,memory,MB,T,sequence length,1024,121.13330078125,121.13330078125,121.13330078125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:15,0.2.1
+rope,huggingface,full,memory,MB,T,sequence length,2048,234.14111328125,234.14111328125,234.14111328125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:15,0.2.1
+rope,huggingface,full,memory,MB,T,sequence length,4096,460.15673828125,460.15673828125,460.15673828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:15,0.2.1
+rope,huggingface,full,memory,MB,T,sequence length,8192,912.18798828125,912.18798828125,912.18798828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:15,0.2.1
+rope,huggingface,full,memory,MB,T,sequence length,16384,1816.25048828125,1816.25048828125,1816.25048828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:15,0.2.1
+swiglu,liger,forward,speed,ms,T,sequence length,1024,5.06441593170166,5.06441593170166,5.06441593170166,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:24,0.2.1
+swiglu,liger,forward,speed,ms,T,sequence length,2048,10.075455665588379,10.075455665588379,10.075455665588379,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:24,0.2.1
+swiglu,liger,forward,speed,ms,T,sequence length,4096,18.001951217651367,18.001951217651367,18.001951217651367,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:24,0.2.1
+swiglu,liger,forward,speed,ms,T,sequence length,8192,35.930015563964844,35.930015563964844,35.930015563964844,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:24,0.2.1
+swiglu,huggingface,forward,speed,ms,T,sequence length,1024,4.582320213317871,4.5821757316589355,4.582464218139648,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:28,0.2.1
+swiglu,huggingface,forward,speed,ms,T,sequence length,2048,9.252832412719727,9.252832412719727,9.252832412719727,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:28,0.2.1
+swiglu,huggingface,forward,speed,ms,T,sequence length,4096,18.160255432128906,18.160255432128906,18.160255432128906,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:28,0.2.1
+swiglu,huggingface,forward,speed,ms,T,sequence length,8192,36.2911376953125,36.2911376953125,36.2911376953125,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:28,0.2.1
+swiglu,liger,full,memory,MB,T,sequence length,1024,1100.25,1100.25,1100.25,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:32,0.2.1
+swiglu,liger,full,memory,MB,T,sequence length,2048,1582.25,1582.25,1582.25,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:32,0.2.1
+swiglu,liger,full,memory,MB,T,sequence length,4096,2546.25,2546.25,2546.25,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:32,0.2.1
+swiglu,liger,full,memory,MB,T,sequence length,8192,4474.25,4474.25,4474.25,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:32,0.2.1
+swiglu,huggingface,full,memory,MB,T,sequence length,1024,1294.25,1294.25,1294.25,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:36,0.2.1
+swiglu,huggingface,full,memory,MB,T,sequence length,2048,1992.25,1992.25,1992.25,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:36,0.2.1
+swiglu,huggingface,full,memory,MB,T,sequence length,4096,3452.25,3452.25,3452.25,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:36,0.2.1
+swiglu,huggingface,full,memory,MB,T,sequence length,8192,6372.25,6372.25,6372.25,"{""B"": 4, ""hidden_size"": 4096, ""dtype"": ""torch.bfloat16"", ""intermediate_size"": 11008, ""hidden_act"": ""silu""}",NVIDIA A100-SXM4-80GB,2024-09-03 15:40:36,0.2.1
+kl_div,liger,full,memory,MB,V,vocab size,4096,1536.0009765625,1536.0009765625,1536.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:40,0.2.1
+kl_div,liger,full,memory,MB,V,vocab size,8192,3072.0009765625,3072.0009765625,3072.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:40,0.2.1
+kl_div,liger,full,memory,MB,V,vocab size,16384,6144.0009765625,6144.0009765625,6144.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:40,0.2.1
+kl_div,liger,full,memory,MB,V,vocab size,32768,12288.0009765625,12288.0009765625,12288.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:40,0.2.1
+kl_div,liger,full,memory,MB,V,vocab size,65536,24576.0,24576.0,24576.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:40,0.2.1
+kl_div,liger,full,memory,MB,V,vocab size,131072,49152.0,49152.0,49152.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:40,0.2.1
+kl_div,torch,full,memory,MB,V,vocab size,4096,1792.0,1792.0,1792.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:41,0.2.1
+kl_div,torch,full,memory,MB,V,vocab size,8192,3584.0,3584.0,3584.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:41,0.2.1
+kl_div,torch,full,memory,MB,V,vocab size,16384,7168.0,7168.0,7168.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:41,0.2.1
+kl_div,torch,full,memory,MB,V,vocab size,32768,14336.0,14336.0,14336.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:41,0.2.1
+kl_div,torch,full,memory,MB,V,vocab size,65536,28672.0,28672.0,28672.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:41,0.2.1
+kl_div,torch,full,memory,MB,V,vocab size,131072,57344.0,57344.0,57344.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:41,0.2.1
+kl_div,liger,forward,speed,ms,V,vocab size,4096,0.30640000104904175,0.30563199520111084,0.30745598673820496,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:43,0.2.1
+kl_div,liger,forward,speed,ms,V,vocab size,8192,0.5763360261917114,0.5754943490028381,0.5773376226425171,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:43,0.2.1
+kl_div,liger,forward,speed,ms,V,vocab size,16384,1.1176480054855347,1.1165119409561157,1.1186367273330688,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:43,0.2.1
+kl_div,liger,forward,speed,ms,V,vocab size,32768,2.1987199783325195,2.1970815658569336,2.200934410095215,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:43,0.2.1
+kl_div,liger,forward,speed,ms,V,vocab size,65536,4.356672286987305,4.355186939239502,4.358956813812256,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:43,0.2.1 +kl_div,liger,forward,speed,ms,V,vocab size,131072,8.697919845581055,8.690688133239746,8.703583717346191,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:43,0.2.1 +kl_div,torch,forward,speed,ms,V,vocab size,4096,1.3298559188842773,1.3287359476089478,1.331385612487793,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:45,0.2.1 +kl_div,torch,forward,speed,ms,V,vocab size,8192,2.594543933868408,2.592736005783081,2.596640110015869,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:45,0.2.1 +kl_div,torch,forward,speed,ms,V,vocab size,16384,5.13375997543335,5.1324286460876465,5.1364288330078125,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:45,0.2.1 +kl_div,torch,forward,speed,ms,V,vocab size,32768,10.225567817687988,10.225190162658691,10.227231979370117,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:45,0.2.1 +kl_div,torch,forward,speed,ms,V,vocab size,65536,20.412960052490234,20.411020278930664,20.415000915527344,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:45,0.2.1 +kl_div,torch,forward,speed,ms,V,vocab size,131072,40.818641662597656,40.816402435302734,40.82087707519531,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:45,0.2.1 +kl_div,liger,full,speed,ms,V,vocab size,4096,2.040031909942627,1.9614335298538208,2.192307233810425,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:46,0.2.1 +kl_div,liger,full,speed,ms,V,vocab size,8192,3.866431951522827,3.7955007553100586,3.8693249225616455,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:46,0.2.1 +kl_div,liger,full,speed,ms,V,vocab size,16384,7.261951923370361,7.255136013031006,7.281760215759277,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:46,0.2.1 +kl_div,liger,full,speed,ms,V,vocab size,32768,15.092127799987793,15.07801628112793,15.09660816192627,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:46,0.2.1 +kl_div,liger,full,speed,ms,V,vocab size,65536,29.921375274658203,29.914867401123047,29.921951293945312,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:46,0.2.1 +kl_div,liger,full,speed,ms,V,vocab size,131072,59.70220947265625,59.70220947265625,59.70220947265625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:46,0.2.1 +kl_div,torch,full,speed,ms,V,vocab size,4096,2.8552000522613525,2.852755069732666,2.856454372406006,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1 +kl_div,torch,full,speed,ms,V,vocab size,8192,5.593632221221924,5.590988636016846,5.594636917114258,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1 +kl_div,torch,full,speed,ms,V,vocab size,16384,11.124671936035156,11.122162818908691,11.125061988830566,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1 +kl_div,torch,full,speed,ms,V,vocab size,32768,23.052032470703125,23.050334930419922,23.052589416503906,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1 +kl_div,torch,full,speed,ms,V,vocab size,65536,46.063167572021484,46.05990219116211,46.06643295288086,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1 +kl_div,torch,full,speed,ms,V,vocab size,131072,92.06393432617188,92.06393432617188,92.06393432617188,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1 +jsd,liger,full,memory,MB,V,vocab size,4096,768.0029296875,768.0029296875,768.0029296875,"{""B"": 4, 
""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1 +jsd,liger,full,memory,MB,V,vocab size,8192,1536.0029296875,1536.0029296875,1536.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1 +jsd,liger,full,memory,MB,V,vocab size,16384,3072.0048828125,3072.0048828125,3072.0048828125,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1 +jsd,liger,full,memory,MB,V,vocab size,32768,6144.0087890625,6144.0087890625,6144.0087890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1 +jsd,liger,full,memory,MB,V,vocab size,65536,12288.0166015625,12288.0166015625,12288.0166015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1 +jsd,liger,full,memory,MB,V,vocab size,131072,24576.015625,24576.015625,24576.015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1 +jsd,torch,full,memory,MB,V,vocab size,4096,1664.0009765625,1664.0009765625,1664.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1 +jsd,torch,full,memory,MB,V,vocab size,8192,3328.0009765625,3328.0009765625,3328.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1 +jsd,torch,full,memory,MB,V,vocab size,16384,6656.0009765625,6656.0009765625,6656.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1 +jsd,torch,full,memory,MB,V,vocab size,32768,13312.0009765625,13312.0009765625,13312.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1 +jsd,torch,full,memory,MB,V,vocab size,65536,26624.0,26624.0,26624.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1 +jsd,torch,full,memory,MB,V,vocab size,131072,53248.0,53248.0,53248.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1 +jsd,liger,forward,speed,ms,V,vocab size,4096,0.4651840031147003,0.4636736214160919,0.4659839868545532,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1 +jsd,liger,forward,speed,ms,V,vocab size,8192,0.927888035774231,0.926751971244812,0.92952960729599,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1 +jsd,liger,forward,speed,ms,V,vocab size,16384,10.96003246307373,10.942886352539062,10.970770835876465,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1 +jsd,liger,forward,speed,ms,V,vocab size,32768,22.405792236328125,22.390380859375,22.41998863220215,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1 +jsd,liger,forward,speed,ms,V,vocab size,65536,43.49095916748047,43.47438049316406,43.50754165649414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1 +jsd,liger,forward,speed,ms,V,vocab size,131072,87.0363540649414,87.0363540649414,87.0363540649414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1 +jsd,torch,forward,speed,ms,V,vocab size,4096,2.4744958877563477,2.4725184440612793,2.4764864444732666,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1 +jsd,torch,forward,speed,ms,V,vocab size,8192,4.8528642654418945,4.851238250732422,4.854745864868164,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1 +jsd,torch,forward,speed,ms,V,vocab size,16384,9.532496452331543,9.528634071350098,9.535890579223633,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1 +jsd,torch,forward,speed,ms,V,vocab size,32768,18.91379165649414,18.911853790283203,18.919116973876953,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1 +jsd,torch,forward,speed,ms,V,vocab 
size,65536,37.70152282714844,37.70074462890625,37.70229721069336,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1 +jsd,torch,forward,speed,ms,V,vocab size,131072,75.37680053710938,75.37680053710938,75.37680053710938,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1 +jsd,liger,full,speed,ms,V,vocab size,4096,1.2074079513549805,1.1739968061447144,1.2760319709777832,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1 +jsd,liger,full,speed,ms,V,vocab size,8192,2.091792106628418,2.0771327018737793,2.106553554534912,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1 +jsd,liger,full,speed,ms,V,vocab size,16384,12.928031921386719,12.8988676071167,12.936230659484863,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1 +jsd,liger,full,speed,ms,V,vocab size,32768,26.55548858642578,26.550823211669922,26.570655822753906,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1 +jsd,liger,full,speed,ms,V,vocab size,65536,51.6833610534668,51.6833610534668,51.6833610534668,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1 +jsd,liger,full,speed,ms,V,vocab size,131072,103.12793731689453,103.12793731689453,103.12793731689453,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1 +jsd,torch,full,speed,ms,V,vocab size,4096,5.397359848022461,5.392876625061035,5.39998722076416,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 +jsd,torch,full,speed,ms,V,vocab size,8192,10.60153579711914,10.597900390625,10.60470962524414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 +jsd,torch,full,speed,ms,V,vocab size,16384,20.9442081451416,20.94247055053711,20.9469051361084,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 +jsd,torch,full,speed,ms,V,vocab size,32768,42.113216400146484,42.113216400146484,42.113216400146484,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 +jsd,torch,full,speed,ms,V,vocab size,65536,83.9959716796875,83.9959716796875,83.9959716796875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 +jsd,torch,full,speed,ms,V,vocab size,131072,167.94175720214844,167.94175720214844,167.94175720214844,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x T,1024,110.02185821533203,110.02185821533203,110.02185821533203,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x T,2048,124.14070129394531,124.14070129394531,124.14070129394531,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x T,4096,143.15420532226562,143.15420532226562,143.15420532226562,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x T,8192,180.90406799316406,180.90406799316406,180.90406799316406,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,torch,forward,speed,ms,BT,B x T,1024,9.556896209716797,9.550745964050293,9.576268196105957,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 
+fused_linear_jsd,torch,forward,speed,ms,BT,B x T,2048,18.73731231689453,18.732704162597656,18.737701416015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 +fused_linear_jsd,torch,forward,speed,ms,BT,B x T,4096,37.830482482910156,37.80821990966797,37.85274124145508,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 +fused_linear_jsd,torch,forward,speed,ms,BT,B x T,8192,75.15289306640625,75.15289306640625,75.15289306640625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x T,1024,111.16019439697266,111.16019439697266,111.16019439697266,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x T,2048,125.6825942993164,125.6825942993164,125.6825942993164,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x T,4096,144.00784301757812,144.00784301757812,144.00784301757812,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x T,8192,182.5832977294922,182.5832977294922,182.5832977294922,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,1024,25.977184295654297,25.968351364135742,25.989356994628906,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,2048,49.48417663574219,49.47330093383789,49.495052337646484,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,4096,98.31510162353516,98.31510162353516,98.31510162353516,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,8192,195.29539489746094,195.29539489746094,195.29539489746094,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,1024,4652.48486328125,4652.48486328125,4652.48486328125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,2048,5231.93798828125,5231.93798828125,5231.93798828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,4096,6391.87548828125,6391.87548828125,6391.87548828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,8192,8711.75,8711.75,8711.75,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB 
HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859375,10609.005859375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +fused_linear_orpo_loss,liger,forward,speed,ms,B,B,2,116.00621032714844,116.00621032714844,116.00621032714844,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:05,0.4.0 +fused_linear_orpo_loss,liger,forward,speed,ms,B,B,4,230.83609008789062,230.83609008789062,230.83609008789062,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:05,0.4.0 +fused_linear_orpo_loss,liger,forward,speed,ms,B,B,8,461.9543151855469,461.9543151855469,461.9543151855469,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:05,0.4.0 +fused_linear_orpo_loss,liger,forward,speed,ms,B,B,16,922.994384765625,922.994384765625,922.994384765625,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:05,0.4.0 +fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,2,39.558860778808594,39.52657699584961,39.591148376464844,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:36,0.4.0 +fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,4,79.9734115600586,79.9734115600586,79.9734115600586,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:36,0.4.0 +fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,8,160.071044921875,160.071044921875,160.071044921875,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:36,0.4.0 +fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,16,321.4681091308594,321.4681091308594,321.4681091308594,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:36,0.4.0 +fused_linear_orpo_loss,liger,full,speed,ms,B,B,2,116.56009674072266,116.56009674072266,116.56009674072266,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:17,0.4.0 +fused_linear_orpo_loss,liger,full,speed,ms,B,B,4,232.43980407714844,232.43980407714844,232.43980407714844,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:17,0.4.0 
+fused_linear_orpo_loss,liger,full,speed,ms,B,B,8,464.5750732421875,464.5750732421875,464.5750732421875,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:17,0.4.0 +fused_linear_orpo_loss,liger,full,speed,ms,B,B,16,926.3385009765625,926.3385009765625,926.3385009765625,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:17,0.4.0 +fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,2,120.68428802490234,120.68428802490234,120.68428802490234,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:58,0.4.0 +fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,4,241.15061950683594,241.15061950683594,241.15061950683594,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:58,0.4.0 +fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,8,492.5342102050781,492.5342102050781,492.5342102050781,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:58,0.4.0 +fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,16,1000.8460693359375,1000.8460693359375,1000.8460693359375,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:58,0.4.0 +fused_linear_orpo_loss,liger,full,memory,MB,B,B,2,14556.626953125,14556.626953125,14556.626953125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:26:42,0.4.0 +fused_linear_orpo_loss,liger,full,memory,MB,B,B,4,14748.689453125,14748.689453125,14748.689453125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:26:42,0.4.0 +fused_linear_orpo_loss,liger,full,memory,MB,B,B,8,15132.814453125,15132.814453125,15132.814453125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:26:42,0.4.0 +fused_linear_orpo_loss,liger,full,memory,MB,B,B,16,15901.064453125,15901.064453125,15901.064453125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:26:42,0.4.0 +fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,2,12488.501953125,12488.501953125,12488.501953125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:27:10,0.4.0 +fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,4,19630.564453125,19630.564453125,19630.564453125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:27:10,0.4.0 +fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,8,33914.6875,33914.6875,33914.6875,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:27:10,0.4.0 +fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,16,62482.9375,62482.9375,62482.9375,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:27:10,0.4.0 
+fused_linear_orpo_loss,liger,forward,speed,ms,B,B,2,31.02783966064453,31.027551651000977,31.164947509765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:30,0.4.0 +fused_linear_orpo_loss,liger,forward,speed,ms,B,B,4,60.88966369628906,60.88966369628906,60.88966369628906,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:30,0.4.0 +fused_linear_orpo_loss,liger,forward,speed,ms,B,B,8,121.08070373535156,121.08070373535156,121.08070373535156,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:30,0.4.0 +fused_linear_orpo_loss,liger,forward,speed,ms,B,B,16,244.36968994140625,244.36968994140625,244.36968994140625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:30,0.4.0 +fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,2,12.9093599319458,12.874624252319336,12.947936058044434,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:57,0.4.0 +fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,4,25.557632446289062,25.526700973510742,25.703763961791992,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:57,0.4.0 +fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,8,51.75590515136719,51.75590515136719,51.75590515136719,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:57,0.4.0 +fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,16,103.8515853881836,103.8515853881836,103.8515853881836,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:57,0.4.0 +fused_linear_orpo_loss,liger,full,speed,ms,B,B,2,32.52537536621094,32.49258041381836,32.558170318603516,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:28,0.4.0 +fused_linear_orpo_loss,liger,full,speed,ms,B,B,4,63.16300964355469,63.16300964355469,63.16300964355469,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:28,0.4.0 +fused_linear_orpo_loss,liger,full,speed,ms,B,B,8,123.02518463134766,123.02518463134766,123.02518463134766,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:28,0.4.0 +fused_linear_orpo_loss,liger,full,speed,ms,B,B,16,247.44105529785156,247.44105529785156,247.44105529785156,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:28,0.4.0 +fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,2,39.32752227783203,39.32701873779297,39.32802200317383,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:59,0.4.0 +fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,4,77.9202880859375,77.9202880859375,77.9202880859375,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA 
A100-SXM4-80GB,2024-11-13 22:07:59,0.4.0 +fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,8,151.6084442138672,151.6084442138672,151.6084442138672,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:59,0.4.0 +fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,16,304.4580993652344,304.4580993652344,304.4580993652344,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:59,0.4.0 +fused_linear_orpo_loss,liger,full,memory,MB,B,B,2,8161.34619140625,8161.34619140625,8161.34619140625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:30,0.4.0 +fused_linear_orpo_loss,liger,full,memory,MB,B,B,4,8209.361328125,8209.361328125,8209.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:30,0.4.0 +fused_linear_orpo_loss,liger,full,memory,MB,B,B,8,8305.392578125,8305.392578125,8305.392578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:30,0.4.0 +fused_linear_orpo_loss,liger,full,memory,MB,B,B,16,8497.455078125,8497.455078125,8497.455078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:30,0.4.0 +fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0 +fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0 +fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0 +fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,16,33418.421875,33418.421875,33418.421875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0 +fused_linear_cpo_loss,liger,forward,speed,ms,B,B,2,31.536447525024414,31.457439422607422,31.543052673339844,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:54:47,0.4.1 +fused_linear_cpo_loss,liger,forward,speed,ms,B,B,4,62.407745361328125,62.407745361328125,62.407745361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:54:47,0.4.1 +fused_linear_cpo_loss,liger,forward,speed,ms,B,B,8,123.64259338378906,123.64259338378906,123.64259338378906,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:54:47,0.4.1 +fused_linear_cpo_loss,liger,forward,speed,ms,B,B,16,245.66575622558594,245.66575622558594,245.66575622558594,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:54:47,0.4.1 
+fused_linear_cpo_loss,huggingface,forward,speed,ms,B,B,2,14.516239166259766,14.514080047607422,14.52575969696045,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:20,0.4.1 +fused_linear_cpo_loss,huggingface,forward,speed,ms,B,B,4,26.087743759155273,25.943340301513672,26.269376754760742,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:20,0.4.1 +fused_linear_cpo_loss,huggingface,forward,speed,ms,B,B,8,51.85932922363281,51.85932922363281,51.85932922363281,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:20,0.4.1 +fused_linear_cpo_loss,huggingface,forward,speed,ms,B,B,16,104.99673461914062,104.99673461914062,104.99673461914062,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:20,0.4.1 +fused_linear_cpo_loss,liger,full,speed,ms,B,B,2,33.309967041015625,33.21604919433594,33.40388488769531,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:55,0.4.1 +fused_linear_cpo_loss,liger,full,speed,ms,B,B,4,63.053470611572266,63.053470611572266,63.053470611572266,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:55,0.4.1 +fused_linear_cpo_loss,liger,full,speed,ms,B,B,8,125.53849792480469,125.53849792480469,125.53849792480469,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:55,0.4.1 +fused_linear_cpo_loss,liger,full,speed,ms,B,B,16,250.22178649902344,250.22178649902344,250.22178649902344,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:55:55,0.4.1 +fused_linear_cpo_loss,huggingface,full,speed,ms,B,B,2,39.45849609375,39.33102798461914,39.58596420288086,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:56:30,0.4.1 +fused_linear_cpo_loss,huggingface,full,speed,ms,B,B,4,77.00272369384766,77.00272369384766,77.00272369384766,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:56:30,0.4.1 +fused_linear_cpo_loss,huggingface,full,speed,ms,B,B,8,154.28419494628906,154.28419494628906,154.28419494628906,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:56:30,0.4.1 +fused_linear_cpo_loss,huggingface,full,speed,ms,B,B,16,309.23162841796875,309.23162841796875,309.23162841796875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:56:30,0.4.1 +fused_linear_cpo_loss,liger,full,memory,MB,B,B,2,8161.34619140625,8161.34619140625,8161.34619140625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:06,0.4.1 +fused_linear_cpo_loss,liger,full,memory,MB,B,B,4,8209.361328125,8209.361328125,8209.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 
16:57:06,0.4.1 +fused_linear_cpo_loss,liger,full,memory,MB,B,B,8,8305.392578125,8305.392578125,8305.392578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:06,0.4.1 +fused_linear_cpo_loss,liger,full,memory,MB,B,B,16,8497.455078125,8497.455078125,8497.455078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:06,0.4.1 +fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1 +fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1 +fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1 +fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1 +fused_linear_simpo_loss,liger,forward,speed,ms,B,B,2,30.28438377380371,30.107013702392578,30.284786224365234,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1 +fused_linear_simpo_loss,liger,forward,speed,ms,B,B,4,58.80876922607422,58.80876922607422,58.80876922607422,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1 +fused_linear_simpo_loss,liger,forward,speed,ms,B,B,8,117.96163177490234,117.96163177490234,117.96163177490234,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1 +fused_linear_simpo_loss,liger,forward,speed,ms,B,B,16,235.60794067382812,235.60794067382812,235.60794067382812,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1 +fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,2,14.513839721679688,14.510687828063965,14.517855644226074,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1 +fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,4,28.78099250793457,28.72719383239746,28.792186737060547,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1 +fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,8,52.5733757019043,52.5733757019043,52.5733757019043,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1 +fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,16,104.44764709472656,104.44764709472656,104.44764709472656,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 
14:27:56,0.4.1 +fused_linear_simpo_loss,liger,full,speed,ms,B,B,2,31.566062927246094,31.457612991333008,31.674514770507812,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1 +fused_linear_simpo_loss,liger,full,speed,ms,B,B,4,61.4403190612793,61.4403190612793,61.4403190612793,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1 +fused_linear_simpo_loss,liger,full,speed,ms,B,B,8,119.97705841064453,119.97705841064453,119.97705841064453,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1 +fused_linear_simpo_loss,liger,full,speed,ms,B,B,16,238.13417053222656,238.13417053222656,238.13417053222656,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1 +fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,2,39.811119079589844,39.65474319458008,39.96749496459961,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1 +fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,4,77.20928192138672,77.20928192138672,77.20928192138672,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1 +fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,8,153.6952667236328,153.6952667236328,153.6952667236328,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1 +fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,16,307.7382507324219,307.7382507324219,307.7382507324219,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1 +fused_linear_simpo_loss,liger,full,memory,MB,B,B,2,7675.3291015625,7675.3291015625,7675.3291015625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1 +fused_linear_simpo_loss,liger,full,memory,MB,B,B,4,7723.3447265625,7723.3447265625,7723.3447265625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1 +fused_linear_simpo_loss,liger,full,memory,MB,B,B,8,7819.3759765625,7819.3759765625,7819.3759765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1 +fused_linear_simpo_loss,liger,full,memory,MB,B,B,16,8011.4384765625,8011.4384765625,8011.4384765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1 +fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 +fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 
+fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 +fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 +distill_jsd_loss,liger,forward,speed,ms,BT,B x T,1024,7.735536098480225,7.729177474975586,7.798131465911865,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2 +distill_jsd_loss,liger,forward,speed,ms,BT,B x T,2048,15.20411205291748,15.165056228637695,15.226079940795898,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2 +distill_jsd_loss,liger,forward,speed,ms,BT,B x T,4096,30.159456253051758,30.126911163330078,30.165311813354492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2 +distill_jsd_loss,liger,forward,speed,ms,BT,B x T,8192,60.24163055419922,60.24163055419922,60.24163055419922,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2 +distill_jsd_loss,torch,forward,speed,ms,BT,B x T,1024,10.906111717224121,10.903244972229004,10.91296672821045,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2 +distill_jsd_loss,torch,forward,speed,ms,BT,B x T,2048,21.480207443237305,21.465139389038086,21.489286422729492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2 +distill_jsd_loss,torch,forward,speed,ms,BT,B x T,4096,42.96339416503906,42.96237564086914,42.96440887451172,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2 +distill_jsd_loss,torch,forward,speed,ms,BT,B x T,8192,85.3946533203125,85.3946533203125,85.3946533203125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2 +distill_jsd_loss,liger,full,speed,ms,BT,B x T,1024,8.312895774841309,8.310400009155273,8.326751708984375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2 
+distill_jsd_loss,liger,full,speed,ms,BT,B x T,2048,15.770208358764648,15.767775535583496,15.774784088134766,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2 +distill_jsd_loss,liger,full,speed,ms,BT,B x T,4096,30.922752380371094,30.920312881469727,30.927898406982422,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2 +distill_jsd_loss,liger,full,speed,ms,BT,B x T,8192,60.70627212524414,60.70627212524414,60.70627212524414,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2 +distill_jsd_loss,torch,full,speed,ms,BT,B x T,1024,28.72480010986328,28.718809127807617,28.728179931640625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2 +distill_jsd_loss,torch,full,speed,ms,BT,B x T,2048,54.281761169433594,54.281761169433594,54.281761169433594,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2 +distill_jsd_loss,torch,full,speed,ms,BT,B x T,4096,107.08905792236328,107.08905792236328,107.08905792236328,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2 +distill_jsd_loss,torch,full,speed,ms,BT,B x T,8192,213.1598663330078,213.1598663330078,213.1598663330078,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2 +distill_jsd_loss,liger,full,memory,MB,BT,B x T,1024,10913.541015625,10913.541015625,10913.541015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2 +distill_jsd_loss,liger,full,memory,MB,BT,B x T,2048,10941.548828125,10941.548828125,10941.548828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2 +distill_jsd_loss,liger,full,memory,MB,BT,B x T,4096,10997.564453125,10997.564453125,10997.564453125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2 +distill_jsd_loss,liger,full,memory,MB,BT,B x T,8192,11109.595703125,11109.595703125,11109.595703125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, 
""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2 +distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,16174.0390625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2 +distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2 +distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2 +distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,3.9951679706573486,3.991487979888916,4.002252578735352,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,7.8037919998168945,7.788575649261475,7.808595180511475,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,15.43172836303711,15.430015563964844,15.4335355758667,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,30.66864013671875,30.66431999206543,30.670501708984375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),32,61.1163215637207,61.1163215637207,61.1163215637207,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,3.8766400814056396,3.8680384159088135,3.8897151947021484,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,7.213727951049805,7.206470489501953,7.229574680328369,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 
08:23:01,0.5.4 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,13.828800201416016,13.810944557189941,13.834943771362305,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,27.0930233001709,27.08517074584961,27.09713363647461,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,54.13715362548828,54.13715362548828,54.13715362548828,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4 +kto_loss,liger,full,speed,ms,B,Batch Size (B),2,4.782928466796875,4.677459239959717,5.3430914878845215,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4 +kto_loss,liger,full,speed,ms,B,Batch Size (B),4,8.517248153686523,8.481344223022461,8.561504364013672,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4 +kto_loss,liger,full,speed,ms,B,Batch Size (B),8,16.547504425048828,16.513471603393555,16.678144454956055,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4 +kto_loss,liger,full,speed,ms,B,Batch Size (B),16,31.891263961791992,31.819705963134766,32.274131774902344,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4 +kto_loss,liger,full,speed,ms,B,Batch Size (B),32,62.953758239746094,62.953758239746094,62.953758239746094,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,6.201632022857666,6.163315296173096,6.314668655395508,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,11.156224250793457,11.142304420471191,11.207296371459961,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,21.249855041503906,21.231891632080078,21.264543533325195,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,41.55686569213867,41.536956787109375,41.57677459716797,"{""T"": 512, ""H"": 1024, ""V"": 128256, 
""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,81.56924438476562,81.56924438476562,81.56924438476562,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4 +kto_loss,liger,full,memory,MB,B,Batch Size (B),2,2585.73876953125,2585.73876953125,2585.73876953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4 +kto_loss,liger,full,memory,MB,B,Batch Size (B),4,3348.9892578125,3348.9892578125,3348.9892578125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4 +kto_loss,liger,full,memory,MB,B,Batch Size (B),8,3361.0048828125,3361.0048828125,3361.0048828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4 +kto_loss,liger,full,memory,MB,B,Batch Size (B),16,3385.0361328125,3385.0361328125,3385.0361328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4 +kto_loss,liger,full,memory,MB,B,Batch Size (B),32,3433.0986328125,3433.0986328125,3433.0986328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,4341.74951171875,4341.74951171875,4341.74951171875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6099.26513671875,6099.26513671875,6099.26513671875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9613.298828125,9613.298828125,9613.298828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16643.365234375,16643.365234375,16643.365234375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30703.498046875,30703.498046875,30703.498046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4 +sparsemax,liger,forward,speed,ms,V,feature 
size,1024,0.41471999883651733,0.4126720130443573,0.42393600940704346,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8 +sparsemax,liger,forward,speed,ms,V,feature size,2048,0.7608320116996765,0.7598080039024353,0.7628800272941589,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8 +sparsemax,liger,forward,speed,ms,V,feature size,4096,1.4561280012130737,1.4540799856185913,1.4581760168075562,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8 +sparsemax,liger,forward,speed,ms,V,feature size,8192,5.288959980010986,5.2848639488220215,5.29986572265625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8 +sparsemax,liger,forward,speed,ms,V,feature size,16384,10.734624862670898,10.729472160339355,11.096882820129395,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8 +sparsemax,liger,forward,speed,ms,V,feature size,32768,21.729312896728516,21.7128963470459,22.20728302001953,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8 +sparsemax,torch,forward,speed,ms,V,feature size,1024,0.42291200160980225,0.42188799381256104,0.42393600940704346,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8 +sparsemax,torch,forward,speed,ms,V,feature size,2048,0.7782400250434875,0.7772160172462463,0.779263973236084,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8 +sparsemax,torch,forward,speed,ms,V,feature size,4096,1.4940160512924194,1.491968035697937,1.4960639476776123,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8 +sparsemax,torch,forward,speed,ms,V,feature size,8192,5.359615802764893,5.356544017791748,5.366579055786133,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8 +sparsemax,torch,forward,speed,ms,V,feature size,16384,10.883584022521973,10.874879837036133,11.224268913269043,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8 +sparsemax,torch,forward,speed,ms,V,feature size,32768,22.19878387451172,22.018457412719727,22.48888397216797,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8 +sparsemax,liger,full,speed,ms,V,feature size,1024,0.4558719992637634,0.45558398962020874,0.45772799849510193,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8 +sparsemax,liger,full,speed,ms,V,feature size,2048,0.8488960266113281,0.8478720188140869,0.8509439826011658,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8 +sparsemax,liger,full,speed,ms,V,feature size,4096,1.6476160287857056,1.6465920209884644,1.6499264240264893,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8 +sparsemax,liger,full,speed,ms,V,feature size,8192,5.664768218994141,5.660672187805176,5.681356906890869,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": 
""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8 +sparsemax,liger,full,speed,ms,V,feature size,16384,11.486207962036133,11.478015899658203,11.874713897705078,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8 +sparsemax,liger,full,speed,ms,V,feature size,32768,23.457279205322266,23.289682388305664,23.76642608642578,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8 +sparsemax,torch,full,speed,ms,V,feature size,1024,0.6021119952201843,0.6010879874229431,0.6041600108146667,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8 +sparsemax,torch,full,speed,ms,V,feature size,2048,1.1212799549102783,1.119264006614685,1.1223039627075195,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8 +sparsemax,torch,full,speed,ms,V,feature size,4096,2.1637120246887207,2.1616640090942383,2.165760040283203,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8 +sparsemax,torch,full,speed,ms,V,feature size,8192,6.693888187408447,6.68723201751709,6.705561637878418,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8 +sparsemax,torch,full,speed,ms,V,feature size,16384,13.523456573486328,13.518848419189453,13.878681182861328,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8 +sparsemax,torch,full,speed,ms,V,feature size,32768,27.604991912841797,27.295129776000977,27.77518081665039,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8 +sparsemax,liger,backward,speed,ms,V,feature size,1024,0.04403200000524521,0.043007999658584595,0.05222399905323982,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8 +sparsemax,liger,backward,speed,ms,V,feature size,2048,0.08806400001049042,0.08713600039482117,0.08806400001049042,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8 +sparsemax,liger,backward,speed,ms,V,feature size,4096,0.1884160041809082,0.1884160041809082,0.18943999707698822,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8 +sparsemax,liger,backward,speed,ms,V,feature size,8192,0.374783992767334,0.37376001477241516,0.37486720085144043,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8 +sparsemax,liger,backward,speed,ms,V,feature size,16384,0.7516160011291504,0.7505919933319092,0.7516160011291504,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8 +sparsemax,liger,backward,speed,ms,V,feature size,32768,1.5738879442214966,1.572864055633545,1.575935959815979,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8 +sparsemax,torch,backward,speed,ms,V,feature size,1024,0.1812479943037033,0.1802240014076233,0.18227200210094452,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8 +sparsemax,torch,backward,speed,ms,V,feature 
size,2048,0.34406399726867676,0.34406399726867676,0.34508800506591797,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8 +sparsemax,torch,backward,speed,ms,V,feature size,4096,0.6717439889907837,0.6707199811935425,0.6727679967880249,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8 +sparsemax,torch,backward,speed,ms,V,feature size,8192,1.3250559568405151,1.3241215944290161,1.3260799646377563,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8 +sparsemax,torch,backward,speed,ms,V,feature size,16384,2.629631996154785,2.628607988357544,2.6306560039520264,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8 +sparsemax,torch,backward,speed,ms,V,feature size,32768,5.236735820770264,5.235712051391602,5.239808082580566,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8 +sparsemax,liger,full,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8 +sparsemax,liger,full,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8 +sparsemax,liger,full,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8 +sparsemax,liger,full,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8 +sparsemax,liger,full,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8 +sparsemax,liger,full,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8 +sparsemax,torch,full,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8 +sparsemax,torch,full,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8 +sparsemax,torch,full,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8 +sparsemax,torch,full,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8 +sparsemax,torch,full,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8 +sparsemax,torch,full,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, 
""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8 +multi_token_attention,liger,forward,speed,ms,L,sequence length,32,0.01740800030529499,0.01740800030529499,0.018432000651955605,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:10,0.1.1 +multi_token_attention,liger,forward,speed,ms,L,sequence length,64,0.018432000651955605,0.01740800030529499,0.01945599913597107,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:10,0.1.1 +multi_token_attention,liger,forward,speed,ms,L,sequence length,128,0.023552000522613525,0.02252800017595291,0.02364799939095974,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:10,0.1.1 +multi_token_attention,liger,forward,speed,ms,L,sequence length,256,0.043007999658584595,0.04198399931192398,0.043007999658584595,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:10,0.1.1 +multi_token_attention,liger,forward,speed,ms,L,sequence length,512,0.12595200538635254,0.12492799758911133,0.12595200538635254,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:10,0.1.1 +multi_token_attention,liger,forward,speed,ms,L,sequence length,1024,0.5283839702606201,0.5253120064735413,0.5294079780578613,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:10,0.1.1 +multi_token_attention,torch,forward,speed,ms,L,sequence length,32,0.2467840015888214,0.24063999950885773,0.2529279887676239,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:11,0.1.1 +multi_token_attention,torch,forward,speed,ms,L,sequence length,64,0.24166400730609894,0.23756800591945648,0.24883200228214264,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:11,0.1.1 +multi_token_attention,torch,forward,speed,ms,L,sequence length,128,0.24268800020217896,0.2385600060224533,0.24985599517822266,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:11,0.1.1 +multi_token_attention,torch,forward,speed,ms,L,sequence length,256,0.24166400730609894,0.23873919248580933,0.24782079458236694,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:11,0.1.1 +multi_token_attention,torch,forward,speed,ms,L,sequence length,512,0.31334400177001953,0.3102720081806183,0.3213888108730316,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:11,0.1.1 +multi_token_attention,torch,forward,speed,ms,L,sequence length,1024,0.719871997833252,0.7167999744415283,0.7260159850120544,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 
3090,2025-04-28 04:46:11,0.1.1 +multi_token_attention,liger,full,speed,ms,L,sequence length,32,0.9349120259284973,0.6543359756469727,0.9494400024414062,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:12,0.1.1 +multi_token_attention,liger,full,speed,ms,L,sequence length,64,0.6215680241584778,0.5631999969482422,0.8916991949081421,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:12,0.1.1 +multi_token_attention,liger,full,speed,ms,L,sequence length,128,0.5406720042228699,0.5335040092468262,0.550003170967102,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:12,0.1.1 +multi_token_attention,liger,full,speed,ms,L,sequence length,256,0.5631999969482422,0.5560320019721985,0.5674688220024109,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:12,0.1.1 +multi_token_attention,liger,full,speed,ms,L,sequence length,512,0.6430720090866089,0.6420480012893677,0.6430720090866089,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:12,0.1.1 +multi_token_attention,liger,full,speed,ms,L,sequence length,1024,2.4780800342559814,2.4770560264587402,2.479987144470215,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:12,0.1.1 +multi_token_attention,torch,full,speed,ms,L,sequence length,32,0.795199990272522,0.78438401222229,0.8038399815559387,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:13,0.1.1 +multi_token_attention,torch,full,speed,ms,L,sequence length,64,0.7362560033798218,0.6504960060119629,0.7464960217475891,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:13,0.1.1 +multi_token_attention,torch,full,speed,ms,L,sequence length,128,0.7680000066757202,0.6437439918518066,0.8105729818344116,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:13,0.1.1 +multi_token_attention,torch,full,speed,ms,L,sequence length,256,0.7685279846191406,0.7586879730224609,0.783519983291626,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:13,0.1.1 +multi_token_attention,torch,full,speed,ms,L,sequence length,512,0.9676799774169922,0.9625599980354309,0.9751039743423462,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:13,0.1.1 +multi_token_attention,torch,full,speed,ms,L,sequence length,1024,2.772480010986328,2.7688961029052734,2.7842559814453125,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:13,0.1.1 +multi_token_attention,liger,backward,speed,ms,L,sequence 
length,32,0.334879994392395,0.3222528100013733,0.6912000179290771,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:14,0.1.1 +multi_token_attention,liger,backward,speed,ms,L,sequence length,64,0.23756800591945648,0.228166401386261,0.2629631757736206,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:14,0.1.1 +multi_token_attention,liger,backward,speed,ms,L,sequence length,128,0.29785600304603577,0.2519040107727051,0.3081727921962738,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:14,0.1.1 +multi_token_attention,liger,backward,speed,ms,L,sequence length,256,0.2590720057487488,0.24391679465770721,0.30832639336586,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:14,0.1.1 +multi_token_attention,liger,backward,speed,ms,L,sequence length,512,0.5171200037002563,0.5169600248336792,0.5181440114974976,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:14,0.1.1 +multi_token_attention,liger,backward,speed,ms,L,sequence length,1024,1.9578880071640015,1.9568639993667603,1.9615744352340698,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:14,0.1.1 +multi_token_attention,torch,backward,speed,ms,L,sequence length,32,0.09830400347709656,0.08908800035715103,0.20353920757770538,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,torch,backward,speed,ms,L,sequence length,64,0.06348799914121628,0.062463998794555664,0.06348799914121628,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,torch,backward,speed,ms,L,sequence length,128,0.09011200070381165,0.08908800035715103,0.09011200070381165,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,torch,backward,speed,ms,L,sequence length,256,0.16383999586105347,0.16383999586105347,0.16486400365829468,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,torch,backward,speed,ms,L,sequence length,512,0.52019202709198,0.5191680192947388,0.52019202709198,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,torch,backward,speed,ms,L,sequence length,1024,1.9763200283050537,1.9752960205078125,1.9763200283050537,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,liger,full,memory,MB,L,sequence length,32,0.97412109375,0.97412109375,0.97412109375,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, 
""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,liger,full,memory,MB,L,sequence length,64,1.53662109375,1.53662109375,1.53662109375,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,liger,full,memory,MB,L,sequence length,128,3.69287109375,3.69287109375,3.69287109375,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,liger,full,memory,MB,L,sequence length,256,13.068359375,13.068359375,13.068359375,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,liger,full,memory,MB,L,sequence length,512,48.974609375,48.974609375,48.974609375,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,liger,full,memory,MB,L,sequence length,1024,192.974609375,192.974609375,192.974609375,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,torch,full,memory,MB,L,sequence length,32,0.9599609375,0.9599609375,0.9599609375,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,torch,full,memory,MB,L,sequence length,64,1.4814453125,1.4814453125,1.4814453125,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,torch,full,memory,MB,L,sequence length,128,3.4736328125,3.4736328125,3.4736328125,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,torch,full,memory,MB,L,sequence length,256,12.19287109375,12.19287109375,12.19287109375,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,torch,full,memory,MB,L,sequence length,512,45.47412109375,45.47412109375,45.47412109375,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +multi_token_attention,torch,full,memory,MB,L,sequence length,1024,178.97412109375,178.97412109375,178.97412109375,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-28 04:46:15,0.1.1 +softmax,liger,forward,speed,ms,N,hidden size,128,0.0071680000983178616,0.0071680000983178616,0.007942399941384792,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:04,0.5.8 +softmax,liger,forward,speed,ms,N,hidden size,256,0.008448000065982342,0.008191999979317188,0.009216000325977802,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:04,0.5.8 +softmax,liger,forward,speed,ms,N,hidden 
size,512,0.013311999849975109,0.01228800043463707,0.013311999849975109,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:04,0.5.8 +softmax,liger,forward,speed,ms,N,hidden size,1024,0.021503999829292297,0.021503999829292297,0.02252800017595291,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:04,0.5.8 +softmax,liger,forward,speed,ms,N,hidden size,2048,0.04095999896526337,0.04095999896526337,0.04198399931192398,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:04,0.5.8 +softmax,liger,forward,speed,ms,N,hidden size,4096,0.0798719972372055,0.0798719972372055,0.08089599758386612,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:04,0.5.8 +softmax,torch,forward,speed,ms,N,hidden size,128,0.006144000217318535,0.006144000217318535,0.0071680000983178616,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:07,0.5.8 +softmax,torch,forward,speed,ms,N,hidden size,256,0.008191999979317188,0.008191999979317188,0.009216000325977802,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:07,0.5.8 +softmax,torch,forward,speed,ms,N,hidden size,512,0.01228800043463707,0.01228800043463707,0.013311999849975109,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:07,0.5.8 +softmax,torch,forward,speed,ms,N,hidden size,1024,0.02252800017595291,0.02252800017595291,0.023552000522613525,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:07,0.5.8 +softmax,torch,forward,speed,ms,N,hidden size,2048,0.057583998888731,0.05734400078654289,0.058368001133203506,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:07,0.5.8 +softmax,torch,forward,speed,ms,N,hidden size,4096,0.08323200047016144,0.08294399827718735,0.08396799862384796,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:07,0.5.8 +softmax,liger,full,speed,ms,N,hidden size,128,0.053247999399900436,0.04505600035190582,0.06172160431742668,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:10,0.5.8 +softmax,liger,full,speed,ms,N,hidden size,256,0.05939200147986412,0.04198399931192398,0.11169920116662979,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:10,0.5.8 +softmax,liger,full,speed,ms,N,hidden size,512,0.11577600240707397,0.07720960676670074,0.16793599724769592,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:10,0.5.8 +softmax,liger,full,speed,ms,N,hidden size,1024,0.12492799758911133,0.10273279249668121,0.2982015907764435,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:10,0.5.8 +softmax,liger,full,speed,ms,N,hidden size,2048,0.1013759970664978,0.10035199671983719,0.12902399897575378,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:10,0.5.8 +softmax,liger,full,speed,ms,N,hidden size,4096,0.19660800695419312,0.19660800695419312,0.19763199985027313,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:10,0.5.8 +softmax,torch,full,speed,ms,N,hidden size,128,0.013311999849975109,0.013311999849975109,0.013504000380635262,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:13,0.5.8 +softmax,torch,full,speed,ms,N,hidden 
size,256,0.019152000546455383,0.018432000651955605,0.01945599913597107,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:13,0.5.8 +softmax,torch,full,speed,ms,N,hidden size,512,0.03891199827194214,0.03788800165057182,0.03891199827194214,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:13,0.5.8 +softmax,torch,full,speed,ms,N,hidden size,1024,0.08396799862384796,0.08396799862384796,0.08499199897050858,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:13,0.5.8 +softmax,torch,full,speed,ms,N,hidden size,2048,0.18329599499702454,0.18329599499702454,0.18432000279426575,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:13,0.5.8 +softmax,torch,full,speed,ms,N,hidden size,4096,0.3307519853115082,0.32972800731658936,0.33169281482696533,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:13,0.5.8 +softmax,liger,forward,speed,ms,N,hidden size,128,0.006335999816656113,0.006144000217318535,0.0071680000983178616,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:16,0.5.8 +softmax,liger,forward,speed,ms,N,hidden size,256,0.0071680000983178616,0.006144000217318535,0.0071680000983178616,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:16,0.5.8 +softmax,liger,forward,speed,ms,N,hidden size,512,0.008191999979317188,0.008191999979317188,0.009216000325977802,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:16,0.5.8 +softmax,liger,forward,speed,ms,N,hidden size,1024,0.013311999849975109,0.01228800043463707,0.013311999849975109,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:16,0.5.8 +softmax,liger,forward,speed,ms,N,hidden size,2048,0.02252800017595291,0.02252800017595291,0.023552000522613525,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:16,0.5.8 +softmax,liger,forward,speed,ms,N,hidden size,4096,0.04095999896526337,0.04095999896526337,0.04198399931192398,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:16,0.5.8 +softmax,torch,forward,speed,ms,N,hidden size,128,0.006144000217318535,0.005119999870657921,0.006144000217318535,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:19,0.5.8 +softmax,torch,forward,speed,ms,N,hidden size,256,0.006207999773323536,0.006144000217318535,0.0071680000983178616,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:19,0.5.8 +softmax,torch,forward,speed,ms,N,hidden size,512,0.008383999578654766,0.008191999979317188,0.009216000325977802,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:19,0.5.8 +softmax,torch,forward,speed,ms,N,hidden size,1024,0.014336000196635723,0.014336000196635723,0.014336000196635723,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:19,0.5.8 +softmax,torch,forward,speed,ms,N,hidden size,2048,0.05939200147986412,0.058368001133203506,0.05939200147986412,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:19,0.5.8 +softmax,torch,forward,speed,ms,N,hidden size,4096,0.06758400052785873,0.06675200164318085,0.06758400052785873,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:19,0.5.8 +softmax,liger,full,speed,ms,N,hidden 
size,128,0.11472000181674957,0.09744639694690704,0.20684799551963806,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:22,0.5.8 +softmax,liger,full,speed,ms,N,hidden size,256,0.15787199139595032,0.10769280046224594,0.20897281169891357,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:22,0.5.8 +softmax,liger,full,speed,ms,N,hidden size,512,0.14028799533843994,0.0832064226269722,0.2879999876022339,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:22,0.5.8 +softmax,liger,full,speed,ms,N,hidden size,1024,0.2088959962129593,0.11446399986743927,0.2972480058670044,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:22,0.5.8 +softmax,liger,full,speed,ms,N,hidden size,2048,0.1443839967250824,0.09318400174379349,0.28278398513793945,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:22,0.5.8 +softmax,liger,full,speed,ms,N,hidden size,4096,0.11673600226640701,0.10035199671983719,0.28074881434440613,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:22,0.5.8 +softmax,torch,full,speed,ms,N,hidden size,128,0.011264000087976456,0.010239999741315842,0.011264000087976456,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,torch,full,speed,ms,N,hidden size,256,0.013311999849975109,0.013311999849975109,0.013632000423967838,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,torch,full,speed,ms,N,hidden size,512,0.01945599913597107,0.01945599913597107,0.01945599913597107,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,torch,full,speed,ms,N,hidden size,1024,0.04198399931192398,0.04198399931192398,0.04224000126123428,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,torch,full,speed,ms,N,hidden size,2048,0.12595200538635254,0.12595200538635254,0.12697599828243256,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,torch,full,speed,ms,N,hidden size,4096,0.19763199985027313,0.19660800695419312,0.19809921085834503,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,liger,full,memory,MB,N,hidden size,128,0.00244140625,0.00244140625,0.00244140625,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,liger,full,memory,MB,N,hidden size,256,0.0048828125,0.0048828125,0.0048828125,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,liger,full,memory,MB,N,hidden size,512,0.009765625,0.009765625,0.009765625,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,liger,full,memory,MB,N,hidden size,1024,0.01953125,0.01953125,0.01953125,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,liger,full,memory,MB,N,hidden size,2048,0.0390625,0.0390625,0.0390625,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,liger,full,memory,MB,N,hidden size,4096,0.078125,0.078125,0.078125,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,torch,full,memory,MB,N,hidden 
size,128,0.0029296875,0.0029296875,0.0029296875,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,torch,full,memory,MB,N,hidden size,256,0.005859375,0.005859375,0.005859375,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,torch,full,memory,MB,N,hidden size,512,0.01171875,0.01171875,0.01171875,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,torch,full,memory,MB,N,hidden size,1024,0.0234375,0.0234375,0.0234375,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,torch,full,memory,MB,N,hidden size,2048,0.046875,0.046875,0.046875,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,torch,full,memory,MB,N,hidden size,4096,0.09375,0.09375,0.09375,"{""M"": 2048, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,liger,full,memory,MB,N,hidden size,128,0.00244140625,0.00244140625,0.00244140625,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,liger,full,memory,MB,N,hidden size,256,0.00244140625,0.00244140625,0.00244140625,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,liger,full,memory,MB,N,hidden size,512,0.0048828125,0.0048828125,0.0048828125,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,liger,full,memory,MB,N,hidden size,1024,0.009765625,0.009765625,0.009765625,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,liger,full,memory,MB,N,hidden size,2048,0.01953125,0.01953125,0.01953125,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,liger,full,memory,MB,N,hidden size,4096,0.0390625,0.0390625,0.0390625,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:25,0.5.8 +softmax,torch,full,memory,MB,N,hidden size,128,0.0029296875,0.0029296875,0.0029296875,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:26,0.5.8 +softmax,torch,full,memory,MB,N,hidden size,256,0.0029296875,0.0029296875,0.0029296875,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:26,0.5.8 +softmax,torch,full,memory,MB,N,hidden size,512,0.005859375,0.005859375,0.005859375,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:26,0.5.8 +softmax,torch,full,memory,MB,N,hidden size,1024,0.01171875,0.01171875,0.01171875,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:26,0.5.8 +softmax,torch,full,memory,MB,N,hidden size,2048,0.0234375,0.0234375,0.0234375,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:26,0.5.8 +softmax,torch,full,memory,MB,N,hidden size,4096,0.046875,0.046875,0.046875,"{""M"": 2048, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 3090,2025-04-30 16:11:26,0.5.8 +sparse_multi_token_attention,liger,forward,speed,ms,L,sequence length,32,0.31436800956726074,0.30646398663520813,0.319487988948822,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:12,0.5.8 +sparse_multi_token_attention,liger,forward,speed,ms,L,sequence 
length,64,0.3779039978981018,0.3678207993507385,0.38410240411758423,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:12,0.5.8 +sparse_multi_token_attention,liger,forward,speed,ms,L,sequence length,128,0.35020801424980164,0.3428351879119873,0.35839998722076416,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:12,0.5.8 +sparse_multi_token_attention,liger,forward,speed,ms,L,sequence length,256,0.5294079780578613,0.5283839702606201,0.5304319858551025,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:12,0.5.8 +sparse_multi_token_attention,liger,forward,speed,ms,L,sequence length,512,1.7315839529037476,1.7304960489273071,1.815551996231079,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:12,0.5.8 +sparse_multi_token_attention,liger,forward,speed,ms,L,sequence length,1024,6.465375900268555,6.462463855743408,6.718054294586182,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:12,0.5.8 +sparse_multi_token_attention,torch,forward,speed,ms,L,sequence length,32,0.5888000130653381,0.5826560258865356,0.5960000157356262,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:13,0.5.8 +sparse_multi_token_attention,torch,forward,speed,ms,L,sequence length,64,0.6010879874229431,0.5947520136833191,0.608128011226654,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:13,0.5.8 +sparse_multi_token_attention,torch,forward,speed,ms,L,sequence length,128,0.5816320180892944,0.5745791792869568,0.5908480286598206,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:13,0.5.8 +sparse_multi_token_attention,torch,forward,speed,ms,L,sequence length,256,0.8591359853744507,0.8529919981956482,0.8627520203590393,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:13,0.5.8 +sparse_multi_token_attention,torch,forward,speed,ms,L,sequence length,512,1.931391954421997,1.925772786140442,1.935705542564392,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:13,0.5.8 +sparse_multi_token_attention,torch,forward,speed,ms,L,sequence length,1024,6.76915168762207,6.761676788330078,7.009791851043701,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:13,0.5.8 +sparse_multi_token_attention,liger,full,speed,ms,L,sequence length,32,2.111056089401245,2.0716030597686768,2.137094497680664,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:16,0.5.8 +sparse_multi_token_attention,liger,full,speed,ms,L,sequence 
length,64,2.174975872039795,2.1364736557006836,2.297856092453003,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:16,0.5.8 +sparse_multi_token_attention,liger,full,speed,ms,L,sequence length,128,2.0894718170166016,2.073791980743408,2.1352319717407227,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:16,0.5.8 +sparse_multi_token_attention,liger,full,speed,ms,L,sequence length,256,2.137216091156006,1.8400319814682007,2.194175958633423,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:16,0.5.8 +sparse_multi_token_attention,liger,full,speed,ms,L,sequence length,512,2.2814719676971436,2.1872639656066895,2.2833151817321777,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:16,0.5.8 +sparse_multi_token_attention,liger,full,speed,ms,L,sequence length,1024,8.308735847473145,8.299519538879395,8.551424026489258,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:16,0.5.8 +sparse_multi_token_attention,torch,full,speed,ms,L,sequence length,32,1.5749119520187378,1.498412847518921,2.170527935028076,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:17,0.5.8 +sparse_multi_token_attention,torch,full,speed,ms,L,sequence length,64,1.494047999382019,1.482604742050171,1.5207936763763428,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:17,0.5.8 +sparse_multi_token_attention,torch,full,speed,ms,L,sequence length,128,1.4581760168075562,1.4419968128204346,2.1133759021759033,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:17,0.5.8 +sparse_multi_token_attention,torch,full,speed,ms,L,sequence length,256,1.7448960542678833,1.7180671691894531,1.7537024021148682,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:17,0.5.8 +sparse_multi_token_attention,torch,full,speed,ms,L,sequence length,512,2.796544075012207,2.7762560844421387,2.8190720081329346,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:17,0.5.8 +sparse_multi_token_attention,torch,full,speed,ms,L,sequence length,1024,9.511823654174805,9.501286506652832,9.787391662597656,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:17,0.5.8 +sparse_multi_token_attention,liger,backward,speed,ms,L,sequence length,32,0.3544960021972656,0.33546239137649536,0.8041215538978577,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:18,0.5.8 +sparse_multi_token_attention,liger,backward,speed,ms,L,sequence length,64,0.32897597551345825,0.32051199674606323,0.3438591957092285,"{""B"": 2, 
""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:18,0.5.8 +sparse_multi_token_attention,liger,backward,speed,ms,L,sequence length,128,0.30931198596954346,0.3002240061759949,0.3197120130062103,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:18,0.5.8 +sparse_multi_token_attention,liger,backward,speed,ms,L,sequence length,256,0.31334400177001953,0.2956160008907318,0.3251904249191284,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:18,0.5.8 +sparse_multi_token_attention,liger,backward,speed,ms,L,sequence length,512,0.447488009929657,0.44646400213241577,0.4485119879245758,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:18,0.5.8 +sparse_multi_token_attention,liger,backward,speed,ms,L,sequence length,1024,1.8585599660873413,1.8574656248092651,1.861631989479065,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:18,0.5.8 +sparse_multi_token_attention,torch,backward,speed,ms,L,sequence length,32,0.25804799795150757,0.24883200228214264,0.30926719307899475,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:18,0.5.8 +sparse_multi_token_attention,torch,backward,speed,ms,L,sequence length,64,0.25804799795150757,0.2514623999595642,0.26668161153793335,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:18,0.5.8 +sparse_multi_token_attention,torch,backward,speed,ms,L,sequence length,128,0.24075199663639069,0.2303999960422516,0.25194239616394043,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:18,0.5.8 +sparse_multi_token_attention,torch,backward,speed,ms,L,sequence length,256,0.24686399102210999,0.23756800591945648,0.2550272047519684,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:18,0.5.8 +sparse_multi_token_attention,torch,backward,speed,ms,L,sequence length,512,0.7045120000839233,0.704479992389679,0.7063615918159485,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:18,0.5.8 +sparse_multi_token_attention,torch,backward,speed,ms,L,sequence length,1024,2.698431968688965,2.697216033935547,2.7013120651245117,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:18,0.5.8 +sparse_multi_token_attention,liger,full,memory,MB,L,sequence length,32,0.3603515625,0.3603515625,0.3603515625,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:19,0.5.8 +sparse_multi_token_attention,liger,full,memory,MB,L,sequence length,64,1.4189453125,1.4189453125,1.4189453125,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": 
true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:19,0.5.8 +sparse_multi_token_attention,liger,full,memory,MB,L,sequence length,128,5.6455078125,5.6455078125,5.6455078125,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:19,0.5.8 +sparse_multi_token_attention,liger,full,memory,MB,L,sequence length,256,22.53662109375,22.53662109375,22.53662109375,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:19,0.5.8 +sparse_multi_token_attention,liger,full,memory,MB,L,sequence length,512,90.06884765625,90.06884765625,90.06884765625,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:19,0.5.8 +sparse_multi_token_attention,liger,full,memory,MB,L,sequence length,1024,360.13330078125,360.13330078125,360.13330078125,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:19,0.5.8 +sparse_multi_token_attention,torch,full,memory,MB,L,sequence length,32,0.45263671875,0.45263671875,0.45263671875,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:19,0.5.8 +sparse_multi_token_attention,torch,full,memory,MB,L,sequence length,64,1.7685546875,1.7685546875,1.7685546875,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:19,0.5.8 +sparse_multi_token_attention,torch,full,memory,MB,L,sequence length,128,7.04833984375,7.04833984375,7.04833984375,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:19,0.5.8 +sparse_multi_token_attention,torch,full,memory,MB,L,sequence length,256,28.15478515625,28.15478515625,28.15478515625,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:19,0.5.8 +sparse_multi_token_attention,torch,full,memory,MB,L,sequence length,512,112.55517578125,112.55517578125,112.55517578125,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:19,0.5.8 +sparse_multi_token_attention,torch,full,memory,MB,L,sequence length,1024,450.10595703125,450.10595703125,450.10595703125,"{""B"": 2, ""C_in"": 4, ""C_out"": 4, ""K"": 3, ""groups"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-30 17:22:19,0.5.8 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,64,0.236735999584198,0.16073599457740784,0.24985599517822266,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:08:54,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,128,0.22323200106620789,0.21503999829292297,0.2323904037475586,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:08:54,0.5.10 
+fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,256,0.24268800020217896,0.2295808047056198,0.25088000297546387,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:08:54,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,512,0.3307519853115082,0.32805120944976807,0.3317759931087494,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:08:54,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,1024,0.8540160059928894,0.851967990398407,0.8595455884933472,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:08:54,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,2048,2.3658719062805176,2.3617537021636963,2.368511915206909,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:08:54,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,4096,8.466431617736816,8.447999954223633,8.480768203735352,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:08:54,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,64,5.16915225982666,5.143871784210205,5.297952175140381,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:01,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,128,10.244048118591309,10.094131469726562,10.48145866394043,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:01,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,256,20.196895599365234,20.145601272583008,21.581132888793945,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:01,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,512,42.183536529541016,41.2415771484375,43.12549591064453,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:01,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,1024,77.73798370361328,77.73798370361328,77.73798370361328,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:01,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,2048,172.90853881835938,172.90853881835938,172.90853881835938,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, 
""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:01,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,4096,346.5686950683594,346.5686950683594,346.5686950683594,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:01,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,64,2.723423957824707,2.68287992477417,2.7842559814453125,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:14,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,128,2.6542398929595947,2.6169726848602295,2.68984317779541,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:14,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,256,2.595871925354004,2.1286911964416504,2.6818559169769287,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:14,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,512,2.738736152648926,2.7115519046783447,2.8180480003356934,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:14,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,1024,2.83457612991333,2.805759906768799,2.88972806930542,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:14,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,2048,6.529168128967285,6.525951862335205,6.66664981842041,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:14,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,4096,23.742895126342773,23.660747528076172,23.825515747070312,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:14,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,64,6.841343879699707,6.725196838378906,6.972832202911377,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:21,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,128,11.825152397155762,11.683839797973633,12.080537796020508,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:21,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence 
length,256,21.856351852416992,21.36012077331543,21.95940589904785,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:21,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,512,42.70033264160156,42.545169830322266,42.855499267578125,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:21,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,1024,87.9656982421875,87.9656982421875,87.9656982421875,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:21,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,2048,181.77536010742188,181.77536010742188,181.77536010742188,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:21,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,4096,368.0634765625,368.0634765625,368.0634765625,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:21,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,64,0.5920320153236389,0.5674688220024109,1.3856768608093262,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:22,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,128,0.6430720090866089,0.6318399906158447,0.6610943675041199,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:22,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,256,0.6456320285797119,0.6359040141105652,0.6676480174064636,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:22,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,512,0.7014399766921997,0.6911231875419617,0.7275007963180542,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:22,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,1024,1.4684159755706787,1.4663679599761963,1.4704639911651611,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:22,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,2048,4.150223731994629,4.14717435836792,4.234445095062256,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 
3090,2025-05-27 15:09:22,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,4096,15.17465591430664,14.853119850158691,15.310848236083984,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:22,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,64,0.6000639796257019,0.5832703709602356,1.2799999713897705,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:25,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,128,0.5550079941749573,0.5488640069961548,0.5914624333381653,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:25,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,256,0.5470079779624939,0.5406720042228699,0.562175989151001,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:25,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,512,0.8714240193367004,0.8617984056472778,1.2751424312591553,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:25,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,1024,2.3746559619903564,2.3727169036865234,2.3797760009765625,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:25,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,2048,8.019968032836914,8.00870418548584,8.2227201461792,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:25,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,4096,28.92291259765625,28.684505462646484,28.97941780090332,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:25,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,64,0.23756800591945648,0.22630399465560913,0.24985599517822266,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:32,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,128,0.25088000297546387,0.24187520146369934,0.25964802503585815,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:32,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,256,0.43110400438308716,0.42920318245887756,0.43212801218032837,"{""batch_size"": 4, 
""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:32,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,512,1.0199040174484253,1.0147839784622192,1.0281280279159546,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:32,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,1024,2.584575891494751,2.578432083129883,2.593791961669922,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:32,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,2048,7.8611040115356445,7.851212978363037,8.14100456237793,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:32,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,4096,27.072511672973633,27.043020248413086,27.129650115966797,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:32,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,64,5.303808212280273,5.205196857452393,5.414611339569092,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:38,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,128,10.352640151977539,10.268671989440918,10.546982765197754,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:38,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,256,20.696575164794922,20.600217819213867,22.168373107910156,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:38,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,512,40.9251823425293,39.459224700927734,42.39113998413086,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:38,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,1024,84.20972442626953,84.20972442626953,84.20972442626953,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:38,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,2048,165.5727996826172,165.5727996826172,165.5727996826172,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:38,0.5.10 
+fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,4096,365.4942626953125,365.4942626953125,365.4942626953125,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:38,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,64,2.5410561561584473,2.5221376419067383,2.574540853500366,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:52,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,128,2.6214399337768555,2.5966720581054688,2.66780161857605,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:52,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,256,2.6818559169769287,2.660710334777832,2.7396223545074463,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:52,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,512,2.9624319076538086,2.959359884262085,2.973695993423462,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:52,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,1024,7.516160011291504,7.5141119956970215,7.782809734344482,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:52,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,2048,22.99033546447754,22.859058380126953,23.101655960083008,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:52,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,4096,79.14390563964844,79.14390563964844,79.14390563964844,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:09:52,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,64,6.206463813781738,6.177548885345459,6.346368312835693,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:00,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,128,11.45395278930664,11.369497299194336,11.57201862335205,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:00,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,256,21.295616149902344,20.8918514251709,22.428876876831055,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, 
""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:00,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,512,46.485904693603516,44.799137115478516,48.172672271728516,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:00,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,1024,87.60115051269531,87.60115051269531,87.60115051269531,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:00,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,2048,210.36146545410156,210.36146545410156,210.36146545410156,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:00,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,4096,456.848388671875,456.848388671875,456.848388671875,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:00,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,64,0.5756800174713135,0.45319682359695435,0.7064127922058105,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:02,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,128,0.5908480286598206,0.48742398619651794,0.6028479933738708,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:02,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,256,0.915615975856781,0.8775680065155029,0.9175040125846863,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:02,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,512,1.9450880289077759,1.9351999759674072,1.9651199579238892,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:02,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,1024,4.930560111999512,4.915200233459473,5.046477317810059,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:02,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,2048,15.102832794189453,14.952447891235352,15.31494426727295,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:02,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence 
length,4096,52.104190826416016,52.104190826416016,52.104190826416016,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:02,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,64,0.4843519926071167,0.4761984050273895,0.6077119708061218,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:05,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,128,0.5319839715957642,0.5222399830818176,0.5335040092468262,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:05,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,256,1.1182080507278442,1.1151360273361206,1.120255947113037,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:05,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,512,2.5815041065216064,2.5763840675354004,2.5960447788238525,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:05,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,1024,7.123968124389648,7.087513446807861,7.359897613525391,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:05,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,2048,24.104448318481445,24.077312469482422,24.161880493164062,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:05,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,4096,86.40716552734375,86.40716552734375,86.40716552734375,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:05,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,64,0.2467840015888214,0.17902079224586487,0.25702399015426636,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:12,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,128,0.23756800591945648,0.23654399812221527,0.24885760247707367,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:12,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,256,0.4567039906978607,0.45158401131629944,0.4638719856739044,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": 
""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:12,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,512,0.8017920255661011,0.7946239709854126,0.8048639893531799,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:12,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,1024,1.9527679681777954,1.9476544857025146,1.9595264196395874,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:12,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,2048,5.405695915222168,5.392384052276611,5.651423931121826,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:12,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,4096,18.608959197998047,18.311372756958008,18.646629333496094,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:12,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,64,6.554111957550049,6.130688190460205,6.872096061706543,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:20,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,128,13.195263862609863,13.134265899658203,13.464166641235352,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:20,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,256,24.001535415649414,23.594995498657227,25.934438705444336,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:20,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,512,50.334720611572266,50.334720611572266,50.334720611572266,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:20,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,1024,107.2701416015625,107.2701416015625,107.2701416015625,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:20,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,2048,218.13658142089844,218.13658142089844,218.13658142089844,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:20,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence 
length,4096,457.2313537597656,457.2313537597656,457.2313537597656,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:20,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,64,2.623487949371338,2.605638265609741,2.6442177295684814,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:34,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,128,2.6389598846435547,2.6225087642669678,2.6781694889068604,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:34,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,256,2.613312005996704,2.589139223098755,2.6998207569122314,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:34,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,512,2.7299840450286865,2.7037951946258545,2.783027172088623,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:34,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,1024,5.588992118835449,5.584896087646484,5.632409572601318,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:34,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,2048,15.91859245300293,15.853568077087402,16.029695510864258,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:34,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,4096,54.28019332885742,54.28019332885742,54.28019332885742,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:34,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,64,8.281087875366211,8.076288223266602,8.5731840133667,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:43,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,128,14.909952163696289,14.721952438354492,15.562975883483887,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:43,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,256,25.10848045349121,25.013248443603516,25.180980682373047,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 
3090,2025-05-27 15:10:43,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,512,53.98118209838867,53.98118209838867,53.98118209838867,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:43,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,1024,115.51538848876953,115.51538848876953,115.51538848876953,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:43,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,2048,234.2144012451172,234.2144012451172,234.2144012451172,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:43,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,4096,493.1143798828125,493.1143798828125,493.1143798828125,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:43,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,64,0.6873279809951782,0.6780927777290344,0.8112127780914307,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:45,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,128,0.6923519968986511,0.6756608486175537,0.8371520042419434,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:45,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,256,0.7854080200195312,0.7739391922950745,0.7946239709854126,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:45,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,512,1.5523840188980103,1.5431679487228394,1.5880192518234253,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:45,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,1024,3.635200023651123,3.634176015853882,3.637446403503418,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:45,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,2048,10.225664138793945,10.196991920471191,10.515456199645996,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:45,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,4096,35.736061096191406,35.612876892089844,35.859249114990234,"{""batch_size"": 2, 
""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:45,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,64,0.4935680031776428,0.4843519926071167,1.2861696481704712,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:48,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,128,0.5950400233268738,0.4885439872741699,0.7454720735549927,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:48,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,256,0.9082880020141602,0.8939520120620728,1.2302591800689697,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:48,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,512,1.994752049446106,1.9916800260543823,2.002943992614746,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:48,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,1024,5.427199840545654,5.400953769683838,5.5943169593811035,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:48,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,2048,16.917503356933594,16.85626792907715,17.202789306640625,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:48,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,4096,58.775550842285156,58.775550842285156,58.775550842285156,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:48,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,64,0.16998399794101715,0.159743994474411,0.24968959391117096,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:52,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,128,0.15515199303627014,0.14643199741840363,0.16281600296497345,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:52,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,256,0.16998399794101715,0.159743994474411,0.25088000297546387,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:52,0.5.10 
+fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,512,0.3307519853115082,0.32767999172210693,0.3317759931087494,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:52,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,1024,0.8550400137901306,0.8529919981956482,0.8581119775772095,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:52,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,2048,2.3664638996124268,2.36456298828125,2.371583938598633,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:52,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,4096,8.253439903259277,8.21452808380127,8.534015655517578,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:52,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,64,5.056511878967285,4.674380779266357,5.254271984100342,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:58,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,128,10.41360092163086,10.147839546203613,10.88619613647461,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:58,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,256,21.108095169067383,19.98341178894043,22.000703811645508,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:58,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,512,39.93907165527344,39.49793243408203,40.380210876464844,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:58,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,1024,87.47724914550781,87.47724914550781,87.47724914550781,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:58,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,2048,162.8107147216797,162.8107147216797,162.8107147216797,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:58,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,4096,318.89202880859375,318.89202880859375,318.89202880859375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, 
""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:58,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,64,2.756608009338379,2.50598406791687,2.862694263458252,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:59,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,128,2.683903932571411,2.656268835067749,2.720358371734619,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:59,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,256,2.6729280948638916,2.649907112121582,2.703104019165039,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:59,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,512,2.8049919605255127,2.7712254524230957,2.848358392715454,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:59,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,1024,2.8816640377044678,2.8426239490509033,2.966118335723877,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:59,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,2048,6.523903846740723,6.52185583114624,6.534143924713135,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:59,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,4096,23.48236846923828,23.36788558959961,23.587430953979492,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:10:59,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,64,6.210592269897461,6.149964809417725,6.439935684204102,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:06,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,128,11.412479400634766,11.000422477722168,12.122776985168457,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:06,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,256,21.02124786376953,20.722354888916016,21.280357360839844,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:06,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence 
length,512,44.49420928955078,43.21909713745117,45.769317626953125,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:06,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,1024,77.97862243652344,77.97862243652344,77.97862243652344,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:06,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,2048,169.87033081054688,169.87033081054688,169.87033081054688,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:06,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,4096,360.7623596191406,360.7623596191406,360.7623596191406,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:06,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,64,0.6484479904174805,0.5443072319030762,1.446675181388855,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:07,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,128,0.5460799932479858,0.536575973033905,0.6473984122276306,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:07,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,256,0.5612640380859375,0.5377407670021057,0.6634495854377747,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:07,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,512,0.6347839832305908,0.6327999830245972,0.7219520211219788,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:07,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,1024,1.4684159755706787,1.4624768495559692,1.4744960069656372,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:07,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,2048,4.150784015655518,4.148223876953125,4.164403438568115,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:07,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,4096,15.233535766601562,14.96678352355957,15.318016052246094,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA 
GeForce RTX 3090,2025-05-27 15:11:07,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,64,0.596992015838623,0.5801728367805481,1.2581120729446411,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:10,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,128,0.5565760135650635,0.456928014755249,0.5724160075187683,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:10,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,256,0.5560640096664429,0.4616512060165405,0.5724160075187683,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:10,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,512,0.8714240193367004,0.8622080087661743,1.2775424718856812,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:10,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,1024,2.3746559619903564,2.371583938598633,2.3776895999908447,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:10,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,2048,8.032719612121582,8.015257835388184,8.314061164855957,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:10,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,4096,29.113344192504883,28.672204971313477,29.20366096496582,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:10,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,64,32.525390625,32.525390625,32.525390625,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:10,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,128,37.7734375,37.7734375,37.7734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:10,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,256,53.2734375,53.2734375,53.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:10,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,512,102.2734375,102.2734375,102.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": 
""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:10,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,1024,272.2734375,272.2734375,272.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:10,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,2048,900.2734375,900.2734375,900.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:10,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,4096,3308.2734375,3308.2734375,3308.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:10,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,64,32.53125,32.53125,32.53125,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:17,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,128,36.8046875,36.8046875,36.8046875,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:17,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,256,53.3359375,53.3359375,53.3359375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:17,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,512,110.5234375,110.5234375,110.5234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:17,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,1024,321.2734375,321.2734375,321.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:17,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,2048,1128.2734375,1128.2734375,1128.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:17,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,4096,4284.2734375,4284.2734375,4284.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:17,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,64,55.2880859375,55.2880859375,55.2880859375,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:18,0.5.10 
+fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,128,72.28515625,72.28515625,72.28515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:18,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,256,119.03515625,119.03515625,119.03515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:18,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,512,265.28515625,265.28515625,265.28515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:18,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,1024,775.28515625,775.28515625,775.28515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:18,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,2048,2659.28515625,2659.28515625,2659.28515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:18,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,4096,9883.28515625,9883.28515625,9883.28515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:18,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,64,55.2919921875,55.2919921875,55.2919921875,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:27,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,128,70.05078125,70.05078125,70.05078125,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:27,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,256,118.34765625,118.34765625,118.34765625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:27,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,512,289.53515625,289.53515625,289.53515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:27,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,1024,920.28515625,920.28515625,920.28515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:27,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence 
length,2048,3335.28515625,3335.28515625,3335.28515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:27,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,4096,12779.28515625,12779.28515625,12779.28515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:27,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,64,74.80078125,74.80078125,74.80078125,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:28,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,128,83.296875,83.296875,83.296875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:28,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,256,114.296875,114.296875,114.296875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:28,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,512,212.296875,212.296875,212.296875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:28,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,1024,552.296875,552.296875,552.296875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:28,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,2048,1808.296875,1808.296875,1808.296875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:28,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,4096,6624.296875,6624.296875,6624.296875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:28,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,64,74.8046875,74.8046875,74.8046875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:38,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,128,82.31640625,82.31640625,82.31640625,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:38,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,256,114.359375,114.359375,114.359375,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, 
""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:38,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,512,228.546875,228.546875,228.546875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:38,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,1024,649.296875,649.296875,649.296875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:38,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,2048,2260.296875,2260.296875,2260.296875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:38,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,4096,8560.296875,8560.296875,8560.296875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:38,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,64,32.525390625,32.525390625,32.525390625,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:39,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,128,37.7734375,37.7734375,37.7734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:39,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,256,53.2734375,53.2734375,53.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:39,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,512,102.2734375,102.2734375,102.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:39,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,1024,272.2734375,272.2734375,272.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:39,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,2048,900.2734375,900.2734375,900.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:39,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,4096,3308.2734375,3308.2734375,3308.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 
15:11:39,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,64,32.53125,32.53125,32.53125,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:46,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,128,36.8046875,36.8046875,36.8046875,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:46,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,256,53.3359375,53.3359375,53.3359375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:46,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,512,110.5234375,110.5234375,110.5234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:46,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,1024,321.2734375,321.2734375,321.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:46,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,2048,1128.2734375,1128.2734375,1128.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:46,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,4096,4284.2734375,4284.2734375,4284.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-27 15:11:46,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,64,0.25600001215934753,0.25436800718307495,0.2605184018611908,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:08,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,128,0.2569279968738556,0.25494399666786194,0.26105600595474243,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:08,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,256,0.25676798820495605,0.2550591826438904,0.2598848044872284,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:08,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,512,0.25841599702835083,0.25681281089782715,0.2625727951526642,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:08,0.5.10 
+fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,1024,0.3150399923324585,0.31407999992370605,0.31611520051956177,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:08,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,2048,0.8260959982872009,0.8238016366958618,0.828614354133606,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:08,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,4096,2.5686399936676025,2.557523012161255,2.5757951736450195,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:08,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,64,5.276463985443115,5.270419120788574,5.286643028259277,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:14,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,128,10.498432159423828,10.476134300231934,10.51439380645752,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:14,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,256,20.82036781311035,20.771360397338867,20.881420135498047,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:14,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,512,42.07323455810547,41.776065826416016,42.370399475097656,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:14,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,1024,81.8509750366211,81.8509750366211,81.8509750366211,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:14,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,2048,165.88720703125,165.88720703125,165.88720703125,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:14,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,4096,331.2662658691406,331.2662658691406,331.2662658691406,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:14,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,64,0.8993600010871887,0.8924031853675842,0.9097279906272888,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, 
""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:25,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,128,0.8939200043678284,0.8890752196311951,0.9034687876701355,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:25,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,256,0.9244480133056641,0.9180480241775513,0.940447986125946,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:25,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,512,0.9229600429534912,0.915289580821991,0.9307839870452881,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:25,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,1024,0.9950560331344604,0.9915199875831604,0.9971520304679871,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:25,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,2048,2.5537919998168945,2.548985481262207,2.5564353466033936,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:25,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,4096,7.698319911956787,7.67669153213501,7.713951587677002,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:25,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,64,5.840767860412598,5.819551944732666,5.864096164703369,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:31,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,128,11.064079284667969,11.050003051757812,11.102252960205078,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:31,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,256,21.443504333496094,21.364646911621094,21.61541748046875,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:31,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,512,42.16088104248047,42.137290954589844,42.18446731567383,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:31,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,1024,84.43017578125,84.43017578125,84.43017578125,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, 
""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:31,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,2048,169.27821350097656,169.27821350097656,169.27821350097656,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:31,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,4096,342.5223388671875,342.5223388671875,342.5223388671875,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:31,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,64,0.49110400676727295,0.4891200065612793,0.49513599276542664,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:32,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,128,0.4911839962005615,0.4894847869873047,0.4949440062046051,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:32,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,256,0.5103520154953003,0.5084800124168396,0.5146496295928955,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:32,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,512,0.5199040174484253,0.5182399749755859,0.5254335999488831,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:32,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,1024,0.6806079745292664,0.6792960166931152,0.681990385055542,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:32,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,2048,1.7373919486999512,1.7352639436721802,1.7395071983337402,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:32,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,4096,5.2151360511779785,5.205132484436035,5.221510410308838,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:32,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,64,0.4123840034008026,0.41091200709342957,0.4163135886192322,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:35,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence 
length,128,0.4136800169944763,0.41203200817108154,0.4168703854084015,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:35,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,256,0.4320639967918396,0.4301888048648834,0.4355071783065796,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:35,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,512,0.44307199120521545,0.44010239839553833,0.4480448067188263,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:35,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,1024,0.9624000191688538,0.9609023928642273,0.9633920192718506,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:35,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,2048,2.6429600715637207,2.641439914703369,2.644223928451538,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:35,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,4096,8.974464416503906,8.973376274108887,8.97913646697998,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:35,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,64,0.2598559856414795,0.2580096125602722,0.2628991901874542,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:40,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,128,0.2602880001068115,0.25900799036026,0.26241281628608704,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:40,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,256,0.2643519937992096,0.2627519965171814,0.26796799898147583,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:40,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,512,0.41286399960517883,0.4122239947319031,0.4134399890899658,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:40,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,1024,0.9781439900398254,0.9763264060020447,0.9801728129386902,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA 
H100 80GB HBM3,2025-05-27 19:25:40,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,2048,2.659600019454956,2.655103921890259,2.6648640632629395,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:40,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,4096,8.184944152832031,8.175705909729004,8.197542190551758,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:40,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,64,5.3048319816589355,5.287481784820557,5.315853118896484,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:47,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,128,10.493408203125,10.434623718261719,10.539365768432617,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:47,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,256,20.872079849243164,20.860185623168945,21.320632934570312,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:47,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,512,41.84241485595703,41.80018615722656,41.884647369384766,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:47,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,1024,84.96883392333984,84.96883392333984,84.96883392333984,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:47,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,2048,169.7915802001953,169.7915802001953,169.7915802001953,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:47,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,4096,345.4809265136719,345.4809265136719,345.4809265136719,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:47,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,64,0.9144960045814514,0.9068800210952759,0.9251199960708618,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:56,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,128,0.9177280068397522,0.9107391834259033,0.9262208342552185,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, 
""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:56,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,256,0.9360480308532715,0.9290496110916138,0.949785590171814,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:56,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,512,1.2921760082244873,1.289574384689331,1.2943040132522583,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:56,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,1024,2.9243199825286865,2.919097423553467,2.9282751083374023,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:56,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,2048,7.83568000793457,7.829171180725098,7.843168258666992,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:56,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,4096,24.4779052734375,24.40936279296875,24.545881271362305,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:25:56,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,64,5.912464141845703,5.879615783691406,5.923999786376953,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:03,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,128,11.05232048034668,11.035250663757324,11.079456329345703,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:03,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,256,21.471296310424805,21.445714950561523,21.49998664855957,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:03,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,512,42.718048095703125,42.69863510131836,42.73746109008789,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:03,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,1024,86.00204467773438,86.00204467773438,86.00204467773438,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:03,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence 
length,2048,177.3928985595703,177.3928985595703,177.3928985595703,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:03,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,4096,373.61773681640625,373.61773681640625,373.61773681640625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:03,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,64,0.5130239725112915,0.5107200145721436,0.5175104141235352,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:05,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,128,0.5187360048294067,0.5168319940567017,0.522816002368927,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:05,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,256,0.5284639596939087,0.5261759757995605,0.5319616198539734,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:05,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,512,0.8799999952316284,0.8791552186012268,0.8812223672866821,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:05,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,1024,1.9606720209121704,1.9588288068771362,1.9625920057296753,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:05,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,2048,5.239616394042969,5.233331203460693,5.246374607086182,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:05,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,4096,16.295886993408203,16.174047470092773,16.315935134887695,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:05,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,64,0.4262079894542694,0.42505601048469543,0.42970240116119385,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:07,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,128,0.43747198581695557,0.43620482087135315,0.4399871826171875,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": 
""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:07,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,256,0.5542719960212708,0.5531839728355408,0.555072009563446,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:07,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,512,1.0854079723358154,1.0841728448867798,1.0862784385681152,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:07,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,1024,2.6914560794830322,2.6902334690093994,2.6927361488342285,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:07,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,2048,8.072175979614258,8.052319526672363,8.081612586975098,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:07,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,4096,27.25152015686035,27.248275756835938,27.25334358215332,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:07,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,64,0.26579201221466064,0.26371198892593384,0.2690303921699524,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:14,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,128,0.26337599754333496,0.26162558794021606,0.2659648060798645,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:14,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,256,0.264384001493454,0.2627967894077301,0.267276793718338,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:14,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,512,0.3535360097885132,0.3527039885520935,0.3543359935283661,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:14,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,1024,0.7347840070724487,0.7331455945968628,0.7361727952957153,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:14,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence 
length,2048,1.8545279502868652,1.850592017173767,1.8574399948120117,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:14,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,4096,5.953392028808594,5.927840232849121,5.962080001831055,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:14,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,64,6.691328048706055,6.674118518829346,6.712192058563232,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:22,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,128,13.332127571105957,13.322579383850098,13.362988471984863,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:22,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,256,26.70470428466797,26.678035736083984,27.087322235107422,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:22,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,512,52.936126708984375,52.936126708984375,52.936126708984375,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:22,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,1024,107.26537322998047,107.26537322998047,107.26537322998047,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:22,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,2048,213.9727020263672,213.9727020263672,213.9727020263672,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:22,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,4096,430.3240966796875,430.3240966796875,430.3240966796875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:22,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,64,0.912992000579834,0.8976320028305054,0.9327296018600464,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:32,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,128,0.9216639995574951,0.9107776284217834,0.9301823973655701,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB 
HBM3,2025-05-27 19:26:32,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,256,0.915615975856781,0.9078848361968994,0.9261952042579651,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:32,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,512,1.1379199028015137,1.1355520486831665,1.1407424211502075,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:32,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,1024,2.277343988418579,2.268371343612671,2.2814719676971436,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:32,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,2048,5.6143999099731445,5.608166217803955,5.673030376434326,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:32,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,4096,17.534591674804688,17.516069412231445,17.57676124572754,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:32,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,64,7.29852819442749,7.287238597869873,7.318784236907959,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:40,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,128,13.901632308959961,13.893203735351562,13.942361831665039,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:40,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,256,27.261056900024414,27.254297256469727,27.288244247436523,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:40,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,512,54.26707077026367,54.26707077026367,54.26707077026367,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:40,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,1024,108.40013122558594,108.40013122558594,108.40013122558594,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:40,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,2048,220.19622802734375,220.19622802734375,220.19622802734375,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, 
""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:40,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,4096,453.9944763183594,453.9944763183594,453.9944763183594,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:40,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,64,0.49564799666404724,0.4941760003566742,0.49819520115852356,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:42,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,128,0.5055680274963379,0.5036479830741882,0.5097920298576355,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:42,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,256,0.5073280334472656,0.5049920082092285,0.5109120011329651,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:42,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,512,0.7868000268936157,0.7859584093093872,0.7878463864326477,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:42,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,1024,1.5349119901657104,1.5336960554122925,1.5368640422821045,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:42,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,2048,3.791167974472046,3.787168025970459,3.802060842514038,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:42,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,4096,11.613519668579102,11.596006393432617,11.618464469909668,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:42,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,64,0.41388800740242004,0.412447988986969,0.417279988527298,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:45,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,128,0.42691200971603394,0.42473599314689636,0.4324415922164917,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:45,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence 
length,256,0.4886400103569031,0.48771199584007263,0.48993921279907227,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:45,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,512,0.9216960072517395,0.9203839898109436,0.9231168031692505,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:45,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,1024,1.9877119064331055,1.9866175651550293,1.9888639450073242,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:45,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,2048,5.659264087677002,5.653772830963135,5.6628031730651855,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:45,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,4096,18.87718391418457,18.870214462280273,18.878368377685547,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:45,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,64,0.26070401072502136,0.258950412273407,0.26361599564552307,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:49,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,128,0.2584800124168396,0.256985604763031,0.26101118326187134,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:49,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,256,0.25942400097846985,0.25811201333999634,0.2618303894996643,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:49,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,512,0.26097601652145386,0.25948798656463623,0.2640959918498993,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:49,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,1024,0.3149600028991699,0.3140160143375397,0.31593599915504456,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:49,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,2048,0.8244799971580505,0.8216319680213928,0.8271167874336243,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": 
""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:49,0.5.10 +fused_neighborhood_attention,liger,forward,speed,ms,seq_len,sequence length,4096,2.5662078857421875,2.5587263107299805,2.5770816802978516,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:49,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,64,5.195775985717773,5.172947406768799,5.230342388153076,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:55,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,128,10.488927841186523,10.467231750488281,10.511955261230469,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:55,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,256,21.20012664794922,21.1026554107666,21.275672912597656,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:55,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,512,43.42755126953125,42.99705123901367,43.858055114746094,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:55,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,1024,84.55020904541016,84.55020904541016,84.55020904541016,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:55,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,2048,169.3335418701172,169.3335418701172,169.3335418701172,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:55,0.5.10 +fused_neighborhood_attention,torch,forward,speed,ms,seq_len,sequence length,4096,340.14495849609375,340.14495849609375,340.14495849609375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:55,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,64,0.8945279717445374,0.886732816696167,0.9055423736572266,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:56,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,128,0.8908159732818604,0.8847360014915466,0.8983359932899475,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:56,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,256,0.9086400270462036,0.9012479782104492,0.9151040315628052,"{""batch_size"": 2, ""hidden_size"": 512, 
""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:56,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,512,0.9225280284881592,0.9153919816017151,0.9314560294151306,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:56,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,1024,0.9986559748649597,0.9929599761962891,1.0019199848175049,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:56,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,2048,2.5703680515289307,2.56607985496521,2.574105739593506,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:56,0.5.10 +fused_neighborhood_attention,liger,full,speed,ms,seq_len,sequence length,4096,7.78985595703125,7.7626495361328125,7.792575836181641,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:26:56,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,64,5.764095783233643,5.736550331115723,5.7790656089782715,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:03,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,128,11.027040481567383,11.009875297546387,11.10332202911377,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:03,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,256,21.499038696289062,21.467283248901367,21.521759033203125,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:03,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,512,42.39520263671875,42.34148025512695,42.44892120361328,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:03,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,1024,85.2570571899414,85.2570571899414,85.2570571899414,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:03,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence length,2048,172.73379516601562,172.73379516601562,172.73379516601562,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:03,0.5.10 +fused_neighborhood_attention,torch,full,speed,ms,seq_len,sequence 
length,4096,347.4947509765625,347.4947509765625,347.4947509765625,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:03,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,64,0.4941760003566742,0.49265921115875244,0.4977791905403137,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:04,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,128,0.49348801374435425,0.49185919761657715,0.4974527955055237,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:04,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,256,0.5101760029792786,0.5087360143661499,0.5148288011550903,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:04,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,512,0.5200639963150024,0.5186240077018738,0.5237439870834351,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:04,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,1024,0.6887840032577515,0.6859776377677917,0.6903167963027954,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:04,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,2048,1.7373759746551514,1.7341376543045044,1.7395455837249756,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:04,0.5.10 +fused_neighborhood_attention,liger,backward,speed,ms,seq_len,sequence length,4096,5.201104164123535,5.196633815765381,5.208876609802246,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:04,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,64,0.4107840061187744,0.40908798575401306,0.41468799114227295,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:06,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,128,0.4121600091457367,0.4106624126434326,0.4156480133533478,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:06,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,256,0.4296959936618805,0.42847999930381775,0.4339391887187958,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA 
H100 80GB HBM3,2025-05-27 19:27:06,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,512,0.43406400084495544,0.4329279959201813,0.43656960129737854,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:06,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,1024,0.9568639993667603,0.9556096196174622,0.9582463502883911,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:06,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,2048,2.6357598304748535,2.634399890899658,2.6394240856170654,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:06,0.5.10 +fused_neighborhood_attention,torch,backward,speed,ms,seq_len,sequence length,4096,8.944831848144531,8.943455696105957,8.947711944580078,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:06,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,64,80.275390625,80.275390625,80.275390625,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:07,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,128,85.5234375,85.5234375,85.5234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:07,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,256,101.0234375,101.0234375,101.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:07,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,512,150.0234375,150.0234375,150.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:07,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,1024,320.0234375,320.0234375,320.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:07,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,2048,948.0234375,948.0234375,948.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:07,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,4096,3356.0234375,3356.0234375,3356.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:07,0.5.10 
+fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,64,80.28125,80.28125,80.28125,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:14,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,128,84.5546875,84.5546875,84.5546875,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:14,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,256,101.0859375,101.0859375,101.0859375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:14,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,512,158.2734375,158.2734375,158.2734375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:14,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,1024,369.0234375,369.0234375,369.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:14,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,2048,1176.0234375,1176.0234375,1176.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:14,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,4096,4332.0234375,4332.0234375,4332.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:14,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,64,103.0380859375,103.0380859375,103.0380859375,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:14,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,128,120.78515625,120.78515625,120.78515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:14,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,256,166.78515625,166.78515625,166.78515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:14,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,512,313.03515625,313.03515625,313.03515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:14,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,1024,823.03515625,823.03515625,823.03515625,"{""batch_size"": 
4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:14,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,2048,2707.03515625,2707.03515625,2707.03515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:14,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,4096,9931.03515625,9931.03515625,9931.03515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:14,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,64,103.0419921875,103.0419921875,103.0419921875,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:22,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,128,117.05078125,117.05078125,117.05078125,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:22,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,256,167.34765625,167.34765625,167.34765625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:22,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,512,337.28515625,337.28515625,337.28515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:22,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,1024,968.03515625,968.03515625,968.03515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:22,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,2048,3383.03515625,3383.03515625,3383.03515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:22,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,4096,12827.03515625,12827.03515625,12827.03515625,"{""batch_size"": 4, ""hidden_size"": 768, ""num_heads"": 12, ""kernel_size"": 7, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:22,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,64,122.55078125,122.55078125,122.55078125,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:22,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,128,131.046875,131.046875,131.046875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": 
true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:22,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,256,162.046875,162.046875,162.046875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:22,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,512,260.046875,260.046875,260.046875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:22,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,1024,600.046875,600.046875,600.046875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:22,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,2048,1856.046875,1856.046875,1856.046875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:22,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,4096,6672.046875,6672.046875,6672.046875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:22,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,64,122.5546875,122.5546875,122.5546875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:32,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,128,130.06640625,130.06640625,130.06640625,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:32,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,256,162.109375,162.109375,162.109375,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:32,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,512,276.296875,276.296875,276.296875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:32,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,1024,697.046875,697.046875,697.046875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:32,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,2048,2308.046875,2308.046875,2308.046875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:32,0.5.10 
+fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,4096,8608.046875,8608.046875,8608.046875,"{""batch_size"": 2, ""hidden_size"": 1024, ""num_heads"": 16, ""kernel_size"": 9, ""dilation"": 1, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:32,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,64,80.275390625,80.275390625,80.275390625,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:32,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,128,85.5234375,85.5234375,85.5234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:32,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,256,101.0234375,101.0234375,101.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:32,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,512,150.0234375,150.0234375,150.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:32,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,1024,320.0234375,320.0234375,320.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:32,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,2048,948.0234375,948.0234375,948.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:32,0.5.10 +fused_neighborhood_attention,liger,full,memory,MB,seq_len,sequence length,4096,3356.0234375,3356.0234375,3356.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:32,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,64,80.28125,80.28125,80.28125,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:39,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,128,84.5546875,84.5546875,84.5546875,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:39,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,256,101.0859375,101.0859375,101.0859375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:39,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,512,158.2734375,158.2734375,158.2734375,"{""batch_size"": 2, ""hidden_size"": 512, 
""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:39,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,1024,369.0234375,369.0234375,369.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:39,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,2048,1176.0234375,1176.0234375,1176.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:39,0.5.10 +fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,4096,4332.0234375,4332.0234375,4332.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:39,0.5.10 +distill_cosine_loss,liger,forward,speed,ms,BT,B x T,1024,13.828096389770508,13.821133041381836,13.885849952697754,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:19:52,0.5.10 +distill_cosine_loss,liger,forward,speed,ms,BT,B x T,2048,27.57427215576172,27.52573432922363,27.579801940917967,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:19:52,0.5.10 +distill_cosine_loss,liger,forward,speed,ms,BT,B x T,4096,54.79423904418945,54.79423904418945,54.79423904418945,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:19:52,0.5.10 +distill_cosine_loss,liger,forward,speed,ms,BT,B x T,8192,109.73490905761719,109.73490905761719,109.73490905761719,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:19:52,0.5.10 +distill_cosine_loss,torch,forward,speed,ms,BT,B x T,1024,16.456703186035156,15.045836448669434,16.761650466918944,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:20:34,0.5.10 +distill_cosine_loss,torch,forward,speed,ms,BT,B x T,2048,29.703168869018555,29.69333839416504,29.71177024841309,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:20:34,0.5.10 +distill_cosine_loss,torch,forward,speed,ms,BT,B x T,4096,59.177982330322266,59.177982330322266,59.177982330322266,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:20:34,0.5.10 +distill_cosine_loss,torch,forward,speed,ms,BT,B x 
T,8192,118.3815689086914,118.3815689086914,118.3815689086914,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:20:34,0.5.10 +distill_cosine_loss,liger,full,speed,ms,BT,B x T,1024,14.654463768005371,14.63398380279541,14.68006420135498,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:21:16,0.5.10 +distill_cosine_loss,liger,full,speed,ms,BT,B x T,2048,28.274688720703125,28.27284507751465,28.279603958129883,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:21:16,0.5.10 +distill_cosine_loss,liger,full,speed,ms,BT,B x T,4096,55.96672058105469,55.96672058105469,55.96672058105469,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:21:16,0.5.10 +distill_cosine_loss,liger,full,speed,ms,BT,B x T,8192,111.38764953613281,111.38764953613281,111.38764953613281,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:21:16,0.5.10 +distill_cosine_loss,torch,full,speed,ms,BT,B x T,1024,37.45382308959961,37.42556076049805,37.482085418701175,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:01,0.5.10 +distill_cosine_loss,torch,full,speed,ms,BT,B x T,2048,73.56620788574219,73.56620788574219,73.56620788574219,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:01,0.5.10 +distill_cosine_loss,torch,full,speed,ms,BT,B x T,4096,145.73056030273438,145.73056030273438,145.73056030273438,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:01,0.5.10 +distill_cosine_loss,torch,full,speed,ms,BT,B x T,8192,291.5000305175781,291.5000305175781,291.5000305175781,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:01,0.5.10 +distill_cosine_loss,liger,full,memory,MB,BT,B x T,1024,5059.26806640625,5059.26806640625,5059.26806640625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:43,0.5.10 +distill_cosine_loss,liger,full,memory,MB,BT,B x T,2048,5087.27587890625,5087.27587890625,5087.27587890625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, 
""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:43,0.5.10 +distill_cosine_loss,liger,full,memory,MB,BT,B x T,4096,5143.29150390625,5143.29150390625,5143.29150390625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:43,0.5.10 +distill_cosine_loss,liger,full,memory,MB,BT,B x T,8192,5255.32275390625,5255.32275390625,5255.32275390625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:43,0.5.10 +distill_cosine_loss,torch,full,memory,MB,BT,B x T,1024,7566.2822265625,7566.2822265625,7566.2822265625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:23:28,0.5.10 +distill_cosine_loss,torch,full,memory,MB,BT,B x T,2048,11590.3134765625,11590.3134765625,11590.3134765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:23:28,0.5.10 +distill_cosine_loss,torch,full,memory,MB,BT,B x T,4096,19654.375,19654.375,19654.375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:23:28,0.5.10 +distill_cosine_loss,torch,full,memory,MB,BT,B x T,8192,35782.5,35782.5,35782.5,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:23:28,0.5.10 +layer_norm,liger,forward,speed,ms,N,hidden size,1024,0.018848000094294548,0.018400000408291817,0.020102400332689285,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:11,0.6.0 +layer_norm,liger,forward,speed,ms,N,hidden size,2048,0.029152000322937965,0.02876799926161766,0.029823999851942062,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:11,0.6.0 +layer_norm,liger,forward,speed,ms,N,hidden size,4096,0.05104000121355057,0.05036799982190132,0.05177599936723709,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:11,0.6.0 +layer_norm,liger,forward,speed,ms,N,hidden size,8192,0.0947519987821579,0.09436800330877304,0.09507200121879578,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:11,0.6.0 +layer_norm,liger,forward,speed,ms,N,hidden size,16384,0.18476800620555878,0.18396799266338348,0.1852159947156906,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:11,0.6.0 +layer_norm,huggingface,forward,speed,ms,N,hidden size,1024,0.023584000766277313,0.023423999547958374,0.023840000852942467,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:14,0.6.0 +layer_norm,huggingface,forward,speed,ms,N,hidden 
size,2048,0.03734400123357773,0.03702399879693985,0.037811201065778746,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:14,0.6.0 +layer_norm,huggingface,forward,speed,ms,N,hidden size,4096,0.06617599725723267,0.06560000032186508,0.06678400188684464,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:14,0.6.0 +layer_norm,huggingface,forward,speed,ms,N,hidden size,8192,0.15267199277877808,0.15190400183200836,0.15347200632095337,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:14,0.6.0 +layer_norm,huggingface,forward,speed,ms,N,hidden size,16384,0.3067840039730072,0.3046143889427185,0.3081152021884918,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:14,0.6.0 +layer_norm,liger,backward,speed,ms,N,hidden size,1024,0.12006399780511856,0.11653760075569153,0.12467200309038162,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:16,0.6.0 +layer_norm,liger,backward,speed,ms,N,hidden size,2048,0.1207360029220581,0.1176128014922142,0.1256511986255646,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:16,0.6.0 +layer_norm,liger,backward,speed,ms,N,hidden size,4096,0.16630400717258453,0.16412800550460815,0.16838400065898895,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:16,0.6.0 +layer_norm,liger,backward,speed,ms,N,hidden size,8192,0.31279999017715454,0.31116798520088196,0.3145279884338379,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:16,0.6.0 +layer_norm,liger,backward,speed,ms,N,hidden size,16384,0.5776320099830627,0.5753471970558167,0.5798912048339844,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:16,0.6.0 +layer_norm,huggingface,backward,speed,ms,N,hidden size,1024,0.0605119988322258,0.059647999703884125,0.061344001442193985,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:18,0.6.0 +layer_norm,huggingface,backward,speed,ms,N,hidden size,2048,0.09967999905347824,0.09849599748849869,0.10099200159311295,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:18,0.6.0 +layer_norm,huggingface,backward,speed,ms,N,hidden size,4096,0.17881600558757782,0.17795200645923615,0.17971199750900269,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:18,0.6.0 +layer_norm,huggingface,backward,speed,ms,N,hidden size,8192,0.33369600772857666,0.3328000009059906,0.33478400111198425,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:18,0.6.0 +layer_norm,huggingface,backward,speed,ms,N,hidden size,16384,0.6424000263214111,0.6412223815917969,0.643455982208252,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:18,0.6.0 +layer_norm,liger,full,speed,ms,N,hidden size,1024,0.26576000452041626,0.2629248082637787,0.2701759934425354,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:21,0.6.0 +layer_norm,liger,full,speed,ms,N,hidden size,2048,0.27427199482917786,0.26999040842056277,0.28091518878936766,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB 
HBM3,2025-07-17 18:18:21,0.6.0 +layer_norm,liger,full,speed,ms,N,hidden size,4096,0.27454400062561035,0.27004799246788025,0.2807359993457794,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:21,0.6.0 +layer_norm,liger,full,speed,ms,N,hidden size,8192,0.40556800365448,0.40403199195861816,0.40723198652267456,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:21,0.6.0 +layer_norm,liger,full,speed,ms,N,hidden size,16384,0.7608960270881653,0.7589311957359314,0.7631679773330688,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:21,0.6.0 +layer_norm,huggingface,full,speed,ms,N,hidden size,1024,0.08025600016117096,0.07942400127649307,0.08111999928951263,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0 +layer_norm,huggingface,full,speed,ms,N,hidden size,2048,0.13315199315547943,0.13180799782276154,0.13468800485134125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0 +layer_norm,huggingface,full,speed,ms,N,hidden size,4096,0.2417600005865097,0.24089600145816803,0.24262399971485138,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0 +layer_norm,huggingface,full,speed,ms,N,hidden size,8192,0.4832639992237091,0.48214399814605713,0.4843647956848145,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0 +layer_norm,huggingface,full,speed,ms,N,hidden size,16384,0.950575977563858,0.9484800100326538,0.9528064012527466,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0 +layer_norm,liger,full,memory,MB,N,hidden size,1024,80.0625,80.0625,80.0625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0 +layer_norm,liger,full,memory,MB,N,hidden size,2048,160.09375,160.09375,160.09375,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0 +layer_norm,liger,full,memory,MB,N,hidden size,4096,320.15625,320.15625,320.15625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0 +layer_norm,liger,full,memory,MB,N,hidden size,8192,640.28125,640.28125,640.28125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0 +layer_norm,liger,full,memory,MB,N,hidden size,16384,1280.53125,1280.53125,1280.53125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0 +layer_norm,huggingface,full,memory,MB,N,hidden size,1024,80.0625,80.0625,80.0625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0 +layer_norm,huggingface,full,memory,MB,N,hidden size,2048,160.09375,160.09375,160.09375,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0 +layer_norm,huggingface,full,memory,MB,N,hidden size,4096,320.15625,320.15625,320.15625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0 +layer_norm,huggingface,full,memory,MB,N,hidden size,8192,640.28125,640.28125,640.28125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 
18:18:23,0.6.0 +layer_norm,huggingface,full,memory,MB,N,hidden size,16384,1280.53125,1280.53125,1280.53125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-07-17 18:18:23,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,forward,speed,ms,H,hidden size,1024,0.01759999990463257,0.017311999574303627,0.017920000478625298,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:20,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,forward,speed,ms,H,hidden size,2048,0.02924799919128418,0.028863999992609024,0.029983999207615852,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:20,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,forward,speed,ms,H,hidden size,4096,0.05129599943757057,0.050624001771211624,0.05209600180387497,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:20,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,forward,speed,ms,H,hidden size,8192,0.09344000369310379,0.09296000003814697,0.09382399916648865,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:20,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,forward,speed,ms,H,hidden size,16384,0.1791680008172989,0.17814399302005768,0.1796800047159195,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:20,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,forward,speed,ms,H,hidden size,32768,0.43830400705337524,0.43744000792503357,0.43929600715637207,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:20,0.6.0 +fused_add_rms_norm,huggingface,forward,speed,ms,H,hidden size,1024,0.060095999389886856,0.059808000922203064,0.06054399907588959,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:23,0.6.0 +fused_add_rms_norm,huggingface,forward,speed,ms,H,hidden size,2048,0.09084799885749817,0.09027200192213058,0.09161599725484848,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:23,0.6.0 +fused_add_rms_norm,huggingface,forward,speed,ms,H,hidden size,4096,0.17820799350738525,0.17744000256061554,0.17897599935531616,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:23,0.6.0 +fused_add_rms_norm,huggingface,forward,speed,ms,H,hidden size,8192,0.312608003616333,0.3118720054626465,0.31324800848960876,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:23,0.6.0 +fused_add_rms_norm,huggingface,forward,speed,ms,H,hidden size,16384,0.574944019317627,0.5740479826927185,0.5756288051605225,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:23,0.6.0 +fused_add_rms_norm,huggingface,forward,speed,ms,H,hidden size,32768,1.0943039655685425,1.0934272289276123,1.0951999425888062,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:23,0.6.0 +fused_add_rms_norm,liger_rms_norm,forward,speed,ms,H,hidden size,1024,0.0352960005402565,0.03481600061058998,0.03811199963092804,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:26,0.6.0 +fused_add_rms_norm,liger_rms_norm,forward,speed,ms,H,hidden size,2048,0.05430399999022484,0.05392000079154968,0.05503999814391136,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:26,0.6.0 
+fused_add_rms_norm,liger_rms_norm,forward,speed,ms,H,hidden size,4096,0.10592000186443329,0.1054655984044075,0.10630399733781815,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:26,0.6.0 +fused_add_rms_norm,liger_rms_norm,forward,speed,ms,H,hidden size,8192,0.19679999351501465,0.19631999731063843,0.19724799692630768,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:26,0.6.0 +fused_add_rms_norm,liger_rms_norm,forward,speed,ms,H,hidden size,16384,0.37436801195144653,0.3733760118484497,0.3752320110797882,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:26,0.6.0 +fused_add_rms_norm,liger_rms_norm,forward,speed,ms,H,hidden size,32768,0.7376000285148621,0.7361343741416931,0.7391359806060791,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:26,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,full,speed,ms,H,hidden size,1024,0.3147200047969818,0.30796160697937014,0.32764801383018494,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:30,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,full,speed,ms,H,hidden size,2048,0.3089919984340668,0.30374398827552795,0.3226880133152008,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:30,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,full,speed,ms,H,hidden size,4096,0.30691200494766235,0.3023296058177948,0.3205504059791565,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:30,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,full,speed,ms,H,hidden size,8192,0.3246079981327057,0.3185984075069428,0.33656961321830753,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:30,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,full,speed,ms,H,hidden size,16384,0.6010559797286987,0.5996800065040588,0.6026239991188049,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:30,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,full,speed,ms,H,hidden size,32768,1.8402559757232666,1.8322880268096924,1.8461120128631592,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:30,0.6.0 +fused_add_rms_norm,huggingface,full,speed,ms,H,hidden size,1024,0.23878400027751923,0.23545600473880768,0.2507520020008087,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:33,0.6.0 +fused_add_rms_norm,huggingface,full,speed,ms,H,hidden size,2048,0.34513600170612335,0.34377598762512207,0.34678399562835693,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:33,0.6.0 +fused_add_rms_norm,huggingface,full,speed,ms,H,hidden size,4096,0.6330879926681519,0.631712019443512,0.6345599889755249,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:33,0.6.0 +fused_add_rms_norm,huggingface,full,speed,ms,H,hidden size,8192,1.1185599565505981,1.1172800064086914,1.1196800470352173,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:33,0.6.0 +fused_add_rms_norm,huggingface,full,speed,ms,H,hidden size,16384,2.0697600841522217,2.0678528785705566,2.0713536739349365,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:33,0.6.0 
+fused_add_rms_norm,huggingface,full,speed,ms,H,hidden size,32768,3.9561920166015625,3.953824043273926,3.9581120014190674,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:33,0.6.0 +fused_add_rms_norm,liger_rms_norm,full,speed,ms,H,hidden size,1024,0.38916800916194916,0.3824320137500763,0.4037184059619903,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:36,0.6.0 +fused_add_rms_norm,liger_rms_norm,full,speed,ms,H,hidden size,2048,0.3890720009803772,0.38193280100822447,0.4032831907272339,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:36,0.6.0 +fused_add_rms_norm,liger_rms_norm,full,speed,ms,H,hidden size,4096,0.39715200662612915,0.3928639888763428,0.41097599267959595,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:36,0.6.0 +fused_add_rms_norm,liger_rms_norm,full,speed,ms,H,hidden size,8192,0.6275200247764587,0.6259520053863525,0.6287999749183655,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:36,0.6.0 +fused_add_rms_norm,liger_rms_norm,full,speed,ms,H,hidden size,16384,1.202239990234375,1.199679970741272,1.2048959732055664,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:36,0.6.0 +fused_add_rms_norm,liger_rms_norm,full,speed,ms,H,hidden size,32768,2.7738559246063232,2.7705343723297116,2.777868890762329,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:36,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,backward,speed,ms,H,hidden size,1024,0.15619200468063354,0.15376000106334686,0.1661248028278351,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:39,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,backward,speed,ms,H,hidden size,2048,0.15825600177049637,0.15600000321865082,0.16911999881267548,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:39,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,backward,speed,ms,H,hidden size,4096,0.16700799763202667,0.16502399742603302,0.1709440052509308,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:39,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,backward,speed,ms,H,hidden size,8192,0.1712000072002411,0.1700800061225891,0.17215999960899353,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:39,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,backward,speed,ms,H,hidden size,16384,0.42505601048469543,0.4233280122280121,0.42691200971603394,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:39,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,backward,speed,ms,H,hidden size,32768,1.4057759642601013,1.3944000005722046,1.4099839925765991,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:39,0.6.0 +fused_add_rms_norm,huggingface,backward,speed,ms,H,hidden size,1024,0.1520960032939911,0.15136000514030457,0.1528960019350052,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:42,0.6.0 +fused_add_rms_norm,huggingface,backward,speed,ms,H,hidden size,2048,0.2533760070800781,0.2524160146713257,0.25436800718307495,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:42,0.6.0 
+fused_add_rms_norm,huggingface,backward,speed,ms,H,hidden size,4096,0.4551039934158325,0.4540799856185913,0.45612800121307373,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:42,0.6.0 +fused_add_rms_norm,huggingface,backward,speed,ms,H,hidden size,8192,0.8053439855575562,0.8038079738616943,0.806656002998352,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:42,0.6.0 +fused_add_rms_norm,huggingface,backward,speed,ms,H,hidden size,16384,1.4933120012283325,1.492095947265625,1.49452805519104,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:42,0.6.0 +fused_add_rms_norm,huggingface,backward,speed,ms,H,hidden size,32768,2.8600640296936035,2.8583295822143557,2.8612607955932616,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:42,0.6.0 +fused_add_rms_norm,liger_rms_norm,backward,speed,ms,H,hidden size,1024,0.20175999402999878,0.199072003364563,0.2154303938150406,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_rms_norm,backward,speed,ms,H,hidden size,2048,0.20263999700546265,0.20000000298023224,0.21675519943237304,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_rms_norm,backward,speed,ms,H,hidden size,4096,0.25276800990104675,0.2515519857406616,0.2539199888706207,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_rms_norm,backward,speed,ms,H,hidden size,8192,0.4322720021009445,0.43088001012802124,0.4336000084877014,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_rms_norm,backward,speed,ms,H,hidden size,16384,0.8288000226020813,0.8266303777694701,0.8311295866966247,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_rms_norm,backward,speed,ms,H,hidden size,32768,2.03987193107605,2.0360767364501955,2.0436416149139403,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,full,memory,MB,H,hidden size,1024,72.546875,72.546875,72.546875,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,full,memory,MB,H,hidden size,2048,145.0859375,145.0859375,145.0859375,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,full,memory,MB,H,hidden size,4096,290.1640625,290.1640625,290.1640625,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,full,memory,MB,H,hidden size,8192,580.3203125,580.3203125,580.3203125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,full,memory,MB,H,hidden size,16384,1160.6328125,1160.6328125,1160.6328125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_fused_add_rms_norm,full,memory,MB,H,hidden size,32768,2321.2578125,2321.2578125,2321.2578125,"{""M"": 2048, 
""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,huggingface,full,memory,MB,H,hidden size,1024,104.03173828125,104.03173828125,104.03173828125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,huggingface,full,memory,MB,H,hidden size,2048,208.05517578125,208.05517578125,208.05517578125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,huggingface,full,memory,MB,H,hidden size,4096,416.10205078125,416.10205078125,416.10205078125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,huggingface,full,memory,MB,H,hidden size,8192,832.19580078125,832.19580078125,832.19580078125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,huggingface,full,memory,MB,H,hidden size,16384,1664.3125,1664.3125,1664.3125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,huggingface,full,memory,MB,H,hidden size,32768,3328.625,3328.625,3328.625,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,1024,104.03564453125,104.03564453125,104.03564453125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,2048,208.06298828125,208.06298828125,208.06298828125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,4096,416.11767578125,416.11767578125,416.11767578125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,8192,832.22705078125,832.22705078125,832.22705078125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,16384,1544.44580078125,1544.44580078125,1544.44580078125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,32768,2960.8837890625,2960.8837890625,2960.8837890625,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0 +fused_linear_grpo_loss_token,liger,forward,speed,ms,B,B,2,40.75366401672363,40.749671173095706,40.75765686035156,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-04 23:58:45,0.6.1 +fused_linear_grpo_loss_token,liger,forward,speed,ms,B,B,4,80.95231628417969,80.95231628417969,80.95231628417969,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-04 23:58:45,0.6.1 +fused_linear_grpo_loss_token,liger,forward,speed,ms,B,B,8,163.58604431152344,163.58604431152344,163.58604431152344,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-04 23:58:45,0.6.1 
+fused_linear_grpo_loss_token,liger,forward,speed,ms,B,B,16,323.6761474609375,323.6761474609375,323.6761474609375,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-04 23:58:45,0.6.1 +fused_linear_grpo_loss_token,torch,forward,speed,ms,B,B,2,23.71225643157959,23.612825775146483,23.8354434967041,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-04 23:59:51,0.6.1 +fused_linear_grpo_loss_token,torch,forward,speed,ms,B,B,4,46.86131286621094,46.80355911254883,46.91906661987304,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-04 23:59:51,0.6.1 +fused_linear_grpo_loss_token,torch,forward,speed,ms,B,B,8,94.54898834228516,94.54898834228516,94.54898834228516,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-04 23:59:51,0.6.1 +fused_linear_grpo_loss_token,torch,forward,speed,ms,B,B,16,189.99501037597656,189.99501037597656,189.99501037597656,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-04 23:59:51,0.6.1 +fused_linear_grpo_loss_token,liger,full,speed,ms,B,B,2,42.67263984680176,42.54085083007813,42.80442886352539,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:00:58,0.6.1 +fused_linear_grpo_loss_token,liger,full,speed,ms,B,B,4,82.2446060180664,82.2446060180664,82.2446060180664,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:00:58,0.6.1 +fused_linear_grpo_loss_token,liger,full,speed,ms,B,B,8,167.00416564941406,167.00416564941406,167.00416564941406,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:00:58,0.6.1 +fused_linear_grpo_loss_token,liger,full,speed,ms,B,B,16,327.0911865234375,327.0911865234375,327.0911865234375,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:00:58,0.6.1 +fused_linear_grpo_loss_token,torch,full,speed,ms,B,B,2,45.36115264892578,45.241344451904304,45.480960845947266,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:02:07,0.6.1 +fused_linear_grpo_loss_token,torch,full,speed,ms,B,B,4,90.00038146972656,90.00038146972656,90.00038146972656,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:02:07,0.6.1 +fused_linear_grpo_loss_token,torch,full,speed,ms,B,B,8,177.22674560546875,177.22674560546875,177.22674560546875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:02:07,0.6.1 +fused_linear_grpo_loss_token,torch,full,speed,ms,B,B,16,356.5383605957031,356.5383605957031,356.5383605957031,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": 
""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:02:07,0.6.1 +fused_linear_grpo_loss_token,liger,backward,speed,ms,B,B,2,1.814527988433838,1.8124799728393555,1.8167808055877686,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:03:11,0.6.1 +fused_linear_grpo_loss_token,liger,backward,speed,ms,B,B,4,1.84934401512146,1.8472959995269775,1.8524160385131836,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:03:11,0.6.1 +fused_linear_grpo_loss_token,liger,backward,speed,ms,B,B,8,1.891327977180481,1.8872319459915161,1.893990397453308,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:03:11,0.6.1 +fused_linear_grpo_loss_token,liger,backward,speed,ms,B,B,16,1.9722239971160889,1.9660799503326416,1.9763200283050537,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:03:11,0.6.1 +fused_linear_grpo_loss_token,torch,backward,speed,ms,B,B,2,22.014975547790527,21.710438537597657,22.19417533874512,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:04:16,0.6.1 +fused_linear_grpo_loss_token,torch,backward,speed,ms,B,B,4,41.83603096008301,41.752165222167974,41.91989669799805,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:04:16,0.6.1 +fused_linear_grpo_loss_token,torch,backward,speed,ms,B,B,8,81.66400146484375,81.66400146484375,81.66400146484375,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:04:16,0.6.1 +fused_linear_grpo_loss_token,torch,backward,speed,ms,B,B,16,162.6429443359375,162.6429443359375,162.6429443359375,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:04:16,0.6.1 +fused_linear_grpo_loss_token,liger,full,memory,MB,B,B,2,7344.77685546875,7344.77685546875,7344.77685546875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:05:31,0.6.1 +fused_linear_grpo_loss_token,liger,full,memory,MB,B,B,4,7408.80029296875,7408.80029296875,7408.80029296875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:05:31,0.6.1 +fused_linear_grpo_loss_token,liger,full,memory,MB,B,B,8,7536.84716796875,7536.84716796875,7536.84716796875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:05:31,0.6.1 +fused_linear_grpo_loss_token,liger,full,memory,MB,B,B,16,7792.94091796875,7792.94091796875,7792.94091796875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:05:31,0.6.1 +fused_linear_grpo_loss_token,torch,full,memory,MB,B,B,2,9083.28125,9083.28125,9083.28125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, 
""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:06:37,0.6.1 +fused_linear_grpo_loss_token,torch,full,memory,MB,B,B,4,13138.3125,13138.3125,13138.3125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:06:37,0.6.1 +fused_linear_grpo_loss_token,torch,full,memory,MB,B,B,8,21250.375,21250.375,21250.375,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:06:37,0.6.1 +fused_linear_grpo_loss_token,torch,full,memory,MB,B,B,16,37474.5,37474.5,37474.5,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""token"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:06:37,0.6.1 +fused_linear_grpo_loss_sequence,liger,forward,speed,ms,B,B,2,40.72038269042969,40.71178131103516,40.728984069824214,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:07:48,0.6.1 +fused_linear_grpo_loss_sequence,liger,forward,speed,ms,B,B,4,81.69369506835938,81.69369506835938,81.69369506835938,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:07:48,0.6.1 +fused_linear_grpo_loss_sequence,liger,forward,speed,ms,B,B,8,162.79653930664062,162.79653930664062,162.79653930664062,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:07:48,0.6.1 +fused_linear_grpo_loss_sequence,liger,forward,speed,ms,B,B,16,323.6546630859375,323.6546630859375,323.6546630859375,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:07:48,0.6.1 +fused_linear_grpo_loss_sequence,torch,forward,speed,ms,B,B,2,23.70047950744629,23.628594589233398,23.732429122924806,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:08:54,0.6.1 +fused_linear_grpo_loss_sequence,torch,forward,speed,ms,B,B,4,47.36921691894531,47.085364532470706,47.65306930541992,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:08:54,0.6.1 +fused_linear_grpo_loss_sequence,torch,forward,speed,ms,B,B,8,94.83366394042969,94.83366394042969,94.83366394042969,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:08:54,0.6.1 +fused_linear_grpo_loss_sequence,torch,forward,speed,ms,B,B,16,190.0963897705078,190.0963897705078,190.0963897705078,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:08:54,0.6.1 +fused_linear_grpo_loss_sequence,liger,full,speed,ms,B,B,2,42.318336486816406,42.15214080810547,42.48453216552734,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:10:02,0.6.1 
+fused_linear_grpo_loss_sequence,liger,full,speed,ms,B,B,4,82.4616928100586,82.4616928100586,82.4616928100586,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:10:02,0.6.1 +fused_linear_grpo_loss_sequence,liger,full,speed,ms,B,B,8,163.43756103515625,163.43756103515625,163.43756103515625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:10:02,0.6.1 +fused_linear_grpo_loss_sequence,liger,full,speed,ms,B,B,16,325.4384765625,325.4384765625,325.4384765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:10:02,0.6.1 +fused_linear_grpo_loss_sequence,torch,full,speed,ms,B,B,2,45.99193572998047,45.80761489868165,46.176256561279295,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:11:10,0.6.1 +fused_linear_grpo_loss_sequence,torch,full,speed,ms,B,B,4,88.57190704345703,88.57190704345703,88.57190704345703,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:11:10,0.6.1 +fused_linear_grpo_loss_sequence,torch,full,speed,ms,B,B,8,176.94105529785156,176.94105529785156,176.94105529785156,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:11:10,0.6.1 +fused_linear_grpo_loss_sequence,torch,full,speed,ms,B,B,16,356.0478820800781,356.0478820800781,356.0478820800781,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:11:10,0.6.1 +fused_linear_grpo_loss_sequence,liger,backward,speed,ms,B,B,2,1.8242560029029846,1.8102271556854248,1.8309119939804077,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:12:14,0.6.1 +fused_linear_grpo_loss_sequence,liger,backward,speed,ms,B,B,4,1.84934401512146,1.846886396408081,1.8534400463104248,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:12:14,0.6.1 +fused_linear_grpo_loss_sequence,liger,backward,speed,ms,B,B,8,1.891327977180481,1.8892799615859985,1.8933759927749634,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:12:14,0.6.1 +fused_linear_grpo_loss_sequence,liger,backward,speed,ms,B,B,16,1.9752960205078125,1.9722239971160889,1.977344036102295,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:12:14,0.6.1 +fused_linear_grpo_loss_sequence,torch,backward,speed,ms,B,B,2,22.0262393951416,21.80997085571289,22.20482559204102,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:13:20,0.6.1 +fused_linear_grpo_loss_sequence,torch,backward,speed,ms,B,B,4,41.54521560668945,41.224806213378905,41.865625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, 
""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:13:20,0.6.1 +fused_linear_grpo_loss_sequence,torch,backward,speed,ms,B,B,8,81.21753692626953,81.21753692626953,81.21753692626953,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:13:20,0.6.1 +fused_linear_grpo_loss_sequence,torch,backward,speed,ms,B,B,16,160.82022094726562,160.82022094726562,160.82022094726562,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:13:20,0.6.1 +fused_linear_grpo_loss_sequence,liger,full,memory,MB,B,B,2,7344.77685546875,7344.77685546875,7344.77685546875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:14:28,0.6.1 +fused_linear_grpo_loss_sequence,liger,full,memory,MB,B,B,4,7408.80029296875,7408.80029296875,7408.80029296875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:14:28,0.6.1 +fused_linear_grpo_loss_sequence,liger,full,memory,MB,B,B,8,7536.84716796875,7536.84716796875,7536.84716796875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:14:28,0.6.1 +fused_linear_grpo_loss_sequence,liger,full,memory,MB,B,B,16,7792.94091796875,7792.94091796875,7792.94091796875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:14:28,0.6.1 +fused_linear_grpo_loss_sequence,torch,full,memory,MB,B,B,2,9083.28125,9083.28125,9083.28125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:15:31,0.6.1 +fused_linear_grpo_loss_sequence,torch,full,memory,MB,B,B,4,13138.3125,13138.3125,13138.3125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:15:31,0.6.1 +fused_linear_grpo_loss_sequence,torch,full,memory,MB,B,B,8,21250.375,21250.375,21250.375,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:15:31,0.6.1 +fused_linear_grpo_loss_sequence,torch,full,memory,MB,B,B,16,37474.5,37474.5,37474.5,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""importance_sampling_level"": ""sequence"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2025-08-05 00:15:31,0.6.1 +llama4_rope,liger,forward,speed,ms,H,hidden size,512,0.08249600231647491,0.08102399855852127,0.08432000130414963,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:01,0.6.1 +llama4_rope,liger,forward,speed,ms,H,hidden size,2048,0.08169600367546082,0.08037760108709335,0.08329600095748901,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:01,0.6.1 +llama4_rope,liger,forward,speed,ms,H,hidden size,8192,0.08128000050783157,0.07980799674987793,0.08329600095748901,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, 
""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:01,0.6.1 +llama4_rope,huggingface,forward,speed,ms,H,hidden size,512,0.03759999945759773,0.03612799942493439,0.03907199949026108,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:03,0.6.1 +llama4_rope,huggingface,forward,speed,ms,H,hidden size,2048,0.06185600161552429,0.061267200857400894,0.06252799928188324,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:03,0.6.1 +llama4_rope,huggingface,forward,speed,ms,H,hidden size,8192,0.206496000289917,0.20582400262355804,0.20716799795627594,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:03,0.6.1 +llama4_rope,liger,backward,speed,ms,H,hidden size,512,0.15404799580574036,0.15241600573062897,0.15615999698638916,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:04,0.6.1 +llama4_rope,liger,backward,speed,ms,H,hidden size,2048,0.1536320000886917,0.15190400183200836,0.1558080017566681,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:04,0.6.1 +llama4_rope,liger,backward,speed,ms,H,hidden size,8192,0.15263999998569489,0.15094399452209473,0.15491199493408203,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:04,0.6.1 +llama4_rope,huggingface,backward,speed,ms,H,hidden size,512,0.13760000467300415,0.13574400544166565,0.14009599387645721,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:05,0.6.1 +llama4_rope,huggingface,backward,speed,ms,H,hidden size,2048,0.13600000739097595,0.13449600338935852,0.1382720023393631,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:05,0.6.1 +llama4_rope,huggingface,backward,speed,ms,H,hidden size,8192,0.21011200547218323,0.20924800634384155,0.21110400557518005,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:05,0.6.1 +llama4_rope,liger,full,speed,ms,H,hidden size,512,0.3652159869670868,0.3619840145111084,0.3699840009212494,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:07,0.6.1 +llama4_rope,liger,full,speed,ms,H,hidden size,2048,0.3599040061235428,0.2881920039653778,0.36559998989105225,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:07,0.6.1 +llama4_rope,liger,full,speed,ms,H,hidden size,8192,0.2874239981174469,0.2852480113506317,0.29029120206832887,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:07,0.6.1 +llama4_rope,huggingface,full,speed,ms,H,hidden size,512,0.24691200256347656,0.24489599466323853,0.24961919784545897,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1 +llama4_rope,huggingface,full,speed,ms,H,hidden 
size,2048,0.24774399399757385,0.24582399427890778,0.2505407989025116,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1 +llama4_rope,huggingface,full,speed,ms,H,hidden size,8192,0.41414400935173035,0.41337600350379944,0.41491198539733887,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1 +llama4_rope,liger,full,memory,MB,H,hidden size,512,37.23486328125,37.23486328125,37.23486328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1 +llama4_rope,liger,full,memory,MB,H,hidden size,2048,52.89111328125,52.89111328125,52.89111328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1 +llama4_rope,liger,full,memory,MB,H,hidden size,8192,115.51611328125,115.51611328125,115.51611328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1 +llama4_rope,huggingface,full,memory,MB,H,hidden size,512,49.64111328125,49.64111328125,49.64111328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1 +llama4_rope,huggingface,full,memory,MB,H,hidden size,2048,102.51611328125,102.51611328125,102.51611328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1 +llama4_rope,huggingface,full,memory,MB,H,hidden size,8192,314.01611328125,314.01611328125,314.01611328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1 +llama4_rope,liger,forward,speed,ms,T,sequence length,1024,0.07417599856853485,0.07248000055551529,0.07596799731254578,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:10,0.6.1 +llama4_rope,liger,forward,speed,ms,T,sequence length,2048,0.08182399719953537,0.08006399869918823,0.08380799740552902,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:10,0.6.1 +llama4_rope,liger,forward,speed,ms,T,sequence length,4096,0.11708799749612808,0.1167680025100708,0.11744000017642975,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:10,0.6.1 +llama4_rope,liger,forward,speed,ms,T,sequence length,8192,0.2165440022945404,0.21596799790859222,0.21715199947357178,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:10,0.6.1 +llama4_rope,liger,forward,speed,ms,T,sequence length,16384,0.41756799817085266,0.41705599427223206,0.41811200976371765,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:10,0.6.1 +llama4_rope,huggingface,forward,speed,ms,T,sequence length,1024,0.11644800007343292,0.11590400338172913,0.11708799749612808,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:12,0.6.1 
+llama4_rope,huggingface,forward,speed,ms,T,sequence length,2048,0.20659199357032776,0.20608000457286835,0.2072640061378479,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:12,0.6.1 +llama4_rope,huggingface,forward,speed,ms,T,sequence length,4096,0.38553598523139954,0.3846847891807556,0.38624000549316406,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:12,0.6.1 +llama4_rope,huggingface,forward,speed,ms,T,sequence length,8192,0.7411519885063171,0.7403839826583862,0.7420480251312256,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:12,0.6.1 +llama4_rope,huggingface,forward,speed,ms,T,sequence length,16384,1.4553920030593872,1.4543871641159059,1.4562879800796509,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:12,0.6.1 +llama4_rope,liger,backward,speed,ms,T,sequence length,1024,0.11840000003576279,0.11711999773979187,0.12031999975442886,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:15,0.6.1 +llama4_rope,liger,backward,speed,ms,T,sequence length,2048,0.12336000055074692,0.12198399752378464,0.12489599734544754,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:15,0.6.1 +llama4_rope,liger,backward,speed,ms,T,sequence length,4096,0.12380799651145935,0.12240000069141388,0.12559999525547028,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:15,0.6.1 +llama4_rope,liger,backward,speed,ms,T,sequence length,8192,0.2170879989862442,0.2165759950876236,0.21753600239753723,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:15,0.6.1 +llama4_rope,liger,backward,speed,ms,T,sequence length,16384,0.4175359904766083,0.41705599427223206,0.4181375920772552,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:15,0.6.1 +llama4_rope,huggingface,backward,speed,ms,T,sequence length,1024,0.1189119964838028,0.11769600212574005,0.12003199756145477,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:17,0.6.1 +llama4_rope,huggingface,backward,speed,ms,T,sequence length,2048,0.21011200547218323,0.20927999913692474,0.21119999885559082,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:17,0.6.1 +llama4_rope,huggingface,backward,speed,ms,T,sequence length,4096,0.39740800857543945,0.3963199853897095,0.39824000000953674,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:17,0.6.1 +llama4_rope,huggingface,backward,speed,ms,T,sequence length,8192,0.7540159821510315,0.7528960108757019,0.7550719976425171,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:17,0.6.1 +llama4_rope,huggingface,backward,speed,ms,T,sequence 
length,16384,1.4822720289230347,1.4810559749603271,1.4833600521087646,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:17,0.6.1 +llama4_rope,liger,full,speed,ms,T,sequence length,1024,0.2874400019645691,0.2853440046310425,0.29052799940109253,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:19,0.6.1 +llama4_rope,liger,full,speed,ms,T,sequence length,2048,0.28646400570869446,0.2845759987831116,0.28963199257850647,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:19,0.6.1 +llama4_rope,liger,full,speed,ms,T,sequence length,4096,0.29897600412368774,0.29660800099372864,0.302131199836731,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:19,0.6.1 +llama4_rope,liger,full,speed,ms,T,sequence length,8192,0.4315840005874634,0.4304639995098114,0.43270400166511536,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:19,0.6.1 +llama4_rope,liger,full,speed,ms,T,sequence length,16384,0.833184003829956,0.8322240114212036,0.8345024228096007,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:19,0.6.1 +llama4_rope,huggingface,full,speed,ms,T,sequence length,1024,0.24592000246047974,0.24396799504756927,0.24876800179481506,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1 +llama4_rope,huggingface,full,speed,ms,T,sequence length,2048,0.4138239920139313,0.41308799386024475,0.4145599901676178,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1 +llama4_rope,huggingface,full,speed,ms,T,sequence length,4096,0.7800959944725037,0.7790719866752625,0.7810239791870117,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1 +llama4_rope,huggingface,full,speed,ms,T,sequence length,8192,1.4911680221557617,1.4902976036071778,1.4922879934310913,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1 +llama4_rope,huggingface,full,speed,ms,T,sequence length,16384,2.9344160556793213,2.9333438873291016,2.9353599548339844,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1 +llama4_rope,liger,full,memory,MB,T,sequence length,1024,73.75830078125,73.75830078125,73.75830078125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1 +llama4_rope,liger,full,memory,MB,T,sequence length,2048,115.51611328125,115.51611328125,115.51611328125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1 +llama4_rope,liger,full,memory,MB,T,sequence length,4096,199.03173828125,199.03173828125,199.03173828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, 
""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1 +llama4_rope,liger,full,memory,MB,T,sequence length,8192,366.06298828125,366.06298828125,366.06298828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1 +llama4_rope,liger,full,memory,MB,T,sequence length,16384,700.12548828125,700.12548828125,700.12548828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1 +llama4_rope,huggingface,full,memory,MB,T,sequence length,1024,173.00830078125,173.00830078125,173.00830078125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1 +llama4_rope,huggingface,full,memory,MB,T,sequence length,2048,314.01611328125,314.01611328125,314.01611328125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1 +llama4_rope,huggingface,full,memory,MB,T,sequence length,4096,596.03173828125,596.03173828125,596.03173828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1 +llama4_rope,huggingface,full,memory,MB,T,sequence length,8192,1160.06298828125,1160.06298828125,1160.06298828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1 +llama4_rope,huggingface,full,memory,MB,T,sequence length,16384,2288.12548828125,2288.12548828125,2288.12548828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1 +tiled_geglu,liger,full,speed,ms,T,sequence length,1024,2.1678080558776855,2.166579246520996,2.1682305335998535,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:48,0.6.3 +tiled_geglu,liger,full,speed,ms,T,sequence length,2048,4.344256401062012,4.343987464904785,4.34452486038208,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:48,0.6.3 +tiled_geglu,liger,full,speed,ms,T,sequence length,4096,8.653023719787598,8.653023719787598,8.653023719787598,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:48,0.6.3 +tiled_geglu,liger,full,speed,ms,T,sequence length,8192,16.909311294555664,16.909311294555664,16.909311294555664,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:48,0.6.3 +tiled_geglu,liger,full,speed,ms,T,sequence length,16384,33.63123321533203,33.63123321533203,33.63123321533203,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": 
""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:48,0.6.3 +tiled_geglu,liger_tiled,full,speed,ms,T,sequence length,1024,3.353935956954956,3.353523015975952,3.35434889793396,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:49,0.6.3 +tiled_geglu,liger_tiled,full,speed,ms,T,sequence length,2048,6.023168087005615,6.023168087005615,6.023168087005615,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:49,0.6.3 +tiled_geglu,liger_tiled,full,speed,ms,T,sequence length,4096,11.495424270629883,11.495424270629883,11.495424270629883,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:49,0.6.3 +tiled_geglu,liger_tiled,full,speed,ms,T,sequence length,8192,23.68614387512207,23.68614387512207,23.68614387512207,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:49,0.6.3 +tiled_geglu,liger_tiled,full,speed,ms,T,sequence length,16384,47.478782653808594,47.478782653808594,47.478782653808594,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:49,0.6.3 +tiled_geglu,liger,forward,speed,ms,T,sequence length,1024,0.6614400148391724,0.6594560146331787,0.6635519862174988,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:52,0.6.3 +tiled_geglu,liger,forward,speed,ms,T,sequence length,2048,1.3471999168395996,1.346560001373291,1.3475840091705322,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:52,0.6.3 +tiled_geglu,liger,forward,speed,ms,T,sequence length,4096,2.752511978149414,2.7261502742767334,2.7844607830047607,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:52,0.6.3 +tiled_geglu,liger,forward,speed,ms,T,sequence length,8192,5.433343887329102,5.433343887329102,5.433343887329102,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:52,0.6.3 +tiled_geglu,liger,forward,speed,ms,T,sequence length,16384,10.712063789367676,10.712063789367676,10.712063789367676,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": 
""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:52,0.6.3 +tiled_geglu,liger_tiled,forward,speed,ms,T,sequence length,1024,0.7403519749641418,0.7402047514915466,0.7413759827613831,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:53,0.6.3 +tiled_geglu,liger_tiled,forward,speed,ms,T,sequence length,2048,1.3941760063171387,1.3895679712295532,1.398144006729126,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:53,0.6.3 +tiled_geglu,liger_tiled,forward,speed,ms,T,sequence length,4096,2.7586560249328613,2.7585408687591553,2.759884834289551,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:53,0.6.3 +tiled_geglu,liger_tiled,forward,speed,ms,T,sequence length,8192,5.789696216583252,5.789696216583252,5.789696216583252,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:53,0.6.3 +tiled_geglu,liger_tiled,forward,speed,ms,T,sequence length,16384,11.810815811157227,11.810815811157227,11.810815811157227,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:53,0.6.3 +tiled_geglu,liger,backward,speed,ms,T,sequence length,1024,1.491968035697937,1.4916608333587646,1.4940160512924194,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:56,0.6.3 +tiled_geglu,liger,backward,speed,ms,T,sequence length,2048,3.0185279846191406,3.0131328105926514,3.0555264949798584,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:56,0.6.3 +tiled_geglu,liger,backward,speed,ms,T,sequence length,4096,6.021120071411133,6.021120071411133,6.021120071411133,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:56,0.6.3 +tiled_geglu,liger,backward,speed,ms,T,sequence length,8192,11.512767791748047,11.512767791748047,11.512767791748047,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:56,0.6.3 +tiled_geglu,liger,backward,speed,ms,T,sequence length,16384,22.806528091430664,22.806528091430664,22.806528091430664,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 
4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:56,0.6.3 +tiled_geglu,liger_tiled,backward,speed,ms,T,sequence length,1024,2.6060800552368164,2.6053311824798584,2.607308864593506,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:57,0.6.3 +tiled_geglu,liger_tiled,backward,speed,ms,T,sequence length,2048,4.665375709533691,4.664742469787598,4.666009426116943,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:57,0.6.3 +tiled_geglu,liger_tiled,backward,speed,ms,T,sequence length,4096,8.71731185913086,8.71731185913086,8.71731185913086,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:57,0.6.3 +tiled_geglu,liger_tiled,backward,speed,ms,T,sequence length,8192,17.99782371520996,17.99782371520996,17.99782371520996,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:57,0.6.3 +tiled_geglu,liger_tiled,backward,speed,ms,T,sequence length,16384,35.64400100708008,35.64400100708008,35.64400100708008,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:57,0.6.3 +tiled_geglu,liger,full,memory,MB,T,sequence length,1024,232.25,232.25,232.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:00,0.6.3 +tiled_geglu,liger,full,memory,MB,T,sequence length,2048,336.25,336.25,336.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:00,0.6.3 +tiled_geglu,liger,full,memory,MB,T,sequence length,4096,544.25,544.25,544.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:00,0.6.3 +tiled_geglu,liger,full,memory,MB,T,sequence length,8192,960.25,960.25,960.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:00,0.6.3 +tiled_geglu,liger,full,memory,MB,T,sequence length,16384,1792.25,1792.25,1792.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:00,0.6.3 +tiled_geglu,liger_tiled,full,memory,MB,T,sequence length,1024,186.25,186.25,186.25,"{""bsz"": 
2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:00,0.6.3 +tiled_geglu,liger_tiled,full,memory,MB,T,sequence length,2048,244.25,244.25,244.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:00,0.6.3 +tiled_geglu,liger_tiled,full,memory,MB,T,sequence length,4096,360.25,360.25,360.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:00,0.6.3 +tiled_geglu,liger_tiled,full,memory,MB,T,sequence length,8192,592.25,592.25,592.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:00,0.6.3 +tiled_geglu,liger_tiled,full,memory,MB,T,sequence length,16384,1056.25,1056.25,1056.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:00,0.6.3 +tiled_geglu,liger,forward,memory,MB,T,sequence length,1024,128.25,128.25,128.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:03,0.6.3 +tiled_geglu,liger,forward,memory,MB,T,sequence length,2048,192.25,192.25,192.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:03,0.6.3 +tiled_geglu,liger,forward,memory,MB,T,sequence length,4096,320.25,320.25,320.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:03,0.6.3 +tiled_geglu,liger,forward,memory,MB,T,sequence length,8192,576.25,576.25,576.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:03,0.6.3 +tiled_geglu,liger,forward,memory,MB,T,sequence length,16384,1088.25,1088.25,1088.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:03,0.6.3 +tiled_geglu,liger_tiled,forward,memory,MB,T,sequence length,1024,92.25,92.25,92.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:04,0.6.3 +tiled_geglu,liger_tiled,forward,memory,MB,T,sequence length,2048,120.25,120.25,120.25,"{""bsz"": 2, 
""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:04,0.6.3 +tiled_geglu,liger_tiled,forward,memory,MB,T,sequence length,4096,176.25,176.25,176.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:04,0.6.3 +tiled_geglu,liger_tiled,forward,memory,MB,T,sequence length,8192,288.25,288.25,288.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:04,0.6.3 +tiled_geglu,liger_tiled,forward,memory,MB,T,sequence length,16384,512.25,512.25,512.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:04,0.6.3 +tiled_geglu,liger,backward,memory,MB,T,sequence length,1024,232.25,232.25,232.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:05,0.6.3 +tiled_geglu,liger,backward,memory,MB,T,sequence length,2048,336.25,336.25,336.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:05,0.6.3 +tiled_geglu,liger,backward,memory,MB,T,sequence length,4096,544.25,544.25,544.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:05,0.6.3 +tiled_geglu,liger,backward,memory,MB,T,sequence length,8192,960.25,960.25,960.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:05,0.6.3 +tiled_geglu,liger,backward,memory,MB,T,sequence length,16384,1792.25,1792.25,1792.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:05,0.6.3 +tiled_geglu,liger_tiled,backward,memory,MB,T,sequence length,1024,186.25,186.25,186.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:06,0.6.3 +tiled_geglu,liger_tiled,backward,memory,MB,T,sequence length,2048,244.25,244.25,244.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:06,0.6.3 +tiled_geglu,liger_tiled,backward,memory,MB,T,sequence length,4096,360.25,360.25,360.25,"{""bsz"": 
2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:06,0.6.3 +tiled_geglu,liger_tiled,backward,memory,MB,T,sequence length,8192,592.25,592.25,592.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:06,0.6.3 +tiled_geglu,liger_tiled,backward,memory,MB,T,sequence length,16384,1056.25,1056.25,1056.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:06,0.6.3 +tiled_swiglu,liger,full,speed,ms,T,sequence length,1024,2.165760040283203,2.164659261703491,2.167193651199341,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:10,0.6.3 +tiled_swiglu,liger,full,speed,ms,T,sequence length,2048,4.371456146240234,4.368383884429932,4.374527931213379,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:10,0.6.3 +tiled_swiglu,liger,full,speed,ms,T,sequence length,4096,8.935423851013184,8.935423851013184,8.935423851013184,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:10,0.6.3 +tiled_swiglu,liger,full,speed,ms,T,sequence length,8192,17.078943252563477,17.078943252563477,17.078943252563477,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:10,0.6.3 +tiled_swiglu,liger,full,speed,ms,T,sequence length,16384,33.74857711791992,33.74857711791992,33.74857711791992,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:10,0.6.3 +tiled_swiglu,liger_tiled,full,speed,ms,T,sequence length,1024,3.3510398864746094,3.3507328033447266,3.3513472080230713,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:11,0.6.3 +tiled_swiglu,liger_tiled,full,speed,ms,T,sequence length,2048,6.023168087005615,6.023168087005615,6.023168087005615,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:11,0.6.3 +tiled_swiglu,liger_tiled,full,speed,ms,T,sequence length,4096,11.609087944030762,11.609087944030762,11.609087944030762,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": 
""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:11,0.6.3 +tiled_swiglu,liger_tiled,full,speed,ms,T,sequence length,8192,23.8591365814209,23.8591365814209,23.8591365814209,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:11,0.6.3 +tiled_swiglu,liger_tiled,full,speed,ms,T,sequence length,16384,47.721473693847656,47.721473693847656,47.721473693847656,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:11,0.6.3 +tiled_swiglu,liger,forward,speed,ms,T,sequence length,1024,0.6594560146331787,0.6594560146331787,0.6604800224304199,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:14,0.6.3 +tiled_swiglu,liger,forward,speed,ms,T,sequence length,2048,1.3537280559539795,1.3527040481567383,1.3547519445419312,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:14,0.6.3 +tiled_swiglu,liger,forward,speed,ms,T,sequence length,4096,2.7152960300445557,2.715123176574707,2.7155072689056396,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:14,0.6.3 +tiled_swiglu,liger,forward,speed,ms,T,sequence length,8192,5.3361921310424805,5.3361921310424805,5.3361921310424805,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:14,0.6.3 +tiled_swiglu,liger,forward,speed,ms,T,sequence length,16384,10.870783805847168,10.870783805847168,10.870783805847168,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:14,0.6.3 +tiled_swiglu,liger_tiled,forward,speed,ms,T,sequence length,1024,0.7395360469818115,0.7383040189743042,0.7413759827613831,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:15,0.6.3 +tiled_swiglu,liger_tiled,forward,speed,ms,T,sequence length,2048,1.3965599536895752,1.387935996055603,1.4024640321731567,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:15,0.6.3 +tiled_swiglu,liger_tiled,forward,speed,ms,T,sequence length,4096,2.7778561115264893,2.777395248413086,2.7780096530914307,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:15,0.6.3 
+tiled_swiglu,liger_tiled,forward,speed,ms,T,sequence length,8192,5.829631805419922,5.829631805419922,5.829631805419922,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:15,0.6.3 +tiled_swiglu,liger_tiled,forward,speed,ms,T,sequence length,16384,11.841535568237305,11.841535568237305,11.841535568237305,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:15,0.6.3 +tiled_swiglu,liger,backward,speed,ms,T,sequence length,1024,1.4970879554748535,1.4961408376693726,1.4970879554748535,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:17,0.6.3 +tiled_swiglu,liger,backward,speed,ms,T,sequence length,2048,3.052351951599121,3.0518529415130615,3.0550782680511475,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:17,0.6.3 +tiled_swiglu,liger,backward,speed,ms,T,sequence length,4096,6.074687957763672,6.074687957763672,6.074687957763672,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:17,0.6.3 +tiled_swiglu,liger,backward,speed,ms,T,sequence length,8192,11.630592346191406,11.630592346191406,11.630592346191406,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:17,0.6.3 +tiled_swiglu,liger,backward,speed,ms,T,sequence length,16384,22.76793670654297,22.76793670654297,22.76793670654297,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:17,0.6.3 +tiled_swiglu,liger_tiled,backward,speed,ms,T,sequence length,1024,2.6021440029144287,2.6000702381134033,2.6032767295837402,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:18,0.6.3 +tiled_swiglu,liger_tiled,backward,speed,ms,T,sequence length,2048,4.641791820526123,4.641791820526123,4.641791820526123,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:18,0.6.3 +tiled_swiglu,liger_tiled,backward,speed,ms,T,sequence length,4096,8.761343955993652,8.761343955993652,8.761343955993652,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:18,0.6.3 +tiled_swiglu,liger_tiled,backward,speed,ms,T,sequence 
length,8192,17.966079711914062,17.966079711914062,17.966079711914062,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:18,0.6.3 +tiled_swiglu,liger_tiled,backward,speed,ms,T,sequence length,16384,35.657344818115234,35.657344818115234,35.657344818115234,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:18,0.6.3 +tiled_swiglu,liger,full,memory,MB,T,sequence length,1024,232.25,232.25,232.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:21,0.6.3 +tiled_swiglu,liger,full,memory,MB,T,sequence length,2048,336.25,336.25,336.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:21,0.6.3 +tiled_swiglu,liger,full,memory,MB,T,sequence length,4096,544.25,544.25,544.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:21,0.6.3 +tiled_swiglu,liger,full,memory,MB,T,sequence length,8192,960.25,960.25,960.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:21,0.6.3 +tiled_swiglu,liger,full,memory,MB,T,sequence length,16384,1792.25,1792.25,1792.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:21,0.6.3 +tiled_swiglu,liger_tiled,full,memory,MB,T,sequence length,1024,186.25,186.25,186.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:22,0.6.3 +tiled_swiglu,liger_tiled,full,memory,MB,T,sequence length,2048,244.25,244.25,244.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:22,0.6.3 +tiled_swiglu,liger_tiled,full,memory,MB,T,sequence length,4096,360.25,360.25,360.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:22,0.6.3 +tiled_swiglu,liger_tiled,full,memory,MB,T,sequence length,8192,592.25,592.25,592.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:22,0.6.3 +tiled_swiglu,liger_tiled,full,memory,MB,T,sequence length,16384,1056.25,1056.25,1056.25,"{""bsz"": 2, ""hidden_size"": 2048, 
""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:22,0.6.3 +tiled_swiglu,liger,forward,memory,MB,T,sequence length,1024,128.25,128.25,128.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:25,0.6.3 +tiled_swiglu,liger,forward,memory,MB,T,sequence length,2048,192.25,192.25,192.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:25,0.6.3 +tiled_swiglu,liger,forward,memory,MB,T,sequence length,4096,320.25,320.25,320.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:25,0.6.3 +tiled_swiglu,liger,forward,memory,MB,T,sequence length,8192,576.25,576.25,576.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:25,0.6.3 +tiled_swiglu,liger,forward,memory,MB,T,sequence length,16384,1088.25,1088.25,1088.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:25,0.6.3 +tiled_swiglu,liger_tiled,forward,memory,MB,T,sequence length,1024,92.25,92.25,92.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:25,0.6.3 +tiled_swiglu,liger_tiled,forward,memory,MB,T,sequence length,2048,120.25,120.25,120.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:25,0.6.3 +tiled_swiglu,liger_tiled,forward,memory,MB,T,sequence length,4096,176.25,176.25,176.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:25,0.6.3 +tiled_swiglu,liger_tiled,forward,memory,MB,T,sequence length,8192,288.25,288.25,288.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:25,0.6.3 +tiled_swiglu,liger_tiled,forward,memory,MB,T,sequence length,16384,512.25,512.25,512.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:25,0.6.3 +tiled_swiglu,liger,backward,memory,MB,T,sequence length,1024,232.25,232.25,232.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": 
""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:27,0.6.3 +tiled_swiglu,liger,backward,memory,MB,T,sequence length,2048,336.25,336.25,336.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:27,0.6.3 +tiled_swiglu,liger,backward,memory,MB,T,sequence length,4096,544.25,544.25,544.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:27,0.6.3 +tiled_swiglu,liger,backward,memory,MB,T,sequence length,8192,960.25,960.25,960.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:27,0.6.3 +tiled_swiglu,liger,backward,memory,MB,T,sequence length,16384,1792.25,1792.25,1792.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:27,0.6.3 +tiled_swiglu,liger_tiled,backward,memory,MB,T,sequence length,1024,186.25,186.25,186.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:28,0.6.3 +tiled_swiglu,liger_tiled,backward,memory,MB,T,sequence length,2048,244.25,244.25,244.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:28,0.6.3 +tiled_swiglu,liger_tiled,backward,memory,MB,T,sequence length,4096,360.25,360.25,360.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:28,0.6.3 +tiled_swiglu,liger_tiled,backward,memory,MB,T,sequence length,8192,592.25,592.25,592.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:28,0.6.3 +tiled_swiglu,liger_tiled,backward,memory,MB,T,sequence length,16384,1056.25,1056.25,1056.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:28,0.6.3 +tiled_geglu,huggingface,full,speed,ms,T,sequence length,1024,2.3357439041137695,2.3357439041137695,2.3375871181488037,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:47,0.6.3 +tiled_geglu,huggingface,full,speed,ms,T,sequence length,2048,4.764671802520752,4.764671802520752,4.764671802520752,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": 
""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:47,0.6.3 +tiled_geglu,huggingface,full,speed,ms,T,sequence length,4096,9.4236478805542,9.4236478805542,9.4236478805542,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:47,0.6.3 +tiled_geglu,huggingface,full,speed,ms,T,sequence length,8192,17.628543853759766,17.628543853759766,17.628543853759766,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:47,0.6.3 +tiled_geglu,huggingface,full,speed,ms,T,sequence length,16384,35.06790542602539,35.06790542602539,35.06790542602539,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:47,0.6.3 +tiled_geglu,deepspeed_tiled,full,speed,ms,T,sequence length,1024,3.418976068496704,3.4176511764526367,3.4203009605407715,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:51,0.6.3 +tiled_geglu,deepspeed_tiled,full,speed,ms,T,sequence length,2048,6.158143997192383,6.158143997192383,6.158143997192383,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:51,0.6.3 +tiled_geglu,deepspeed_tiled,full,speed,ms,T,sequence length,4096,11.934720039367676,11.934720039367676,11.934720039367676,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:51,0.6.3 +tiled_geglu,deepspeed_tiled,full,speed,ms,T,sequence length,8192,24.731647491455078,24.731647491455078,24.731647491455078,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:51,0.6.3 +tiled_geglu,deepspeed_tiled,full,speed,ms,T,sequence length,16384,49.46227264404297,49.46227264404297,49.46227264404297,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:51,0.6.3 +tiled_geglu,huggingface,forward,speed,ms,T,sequence length,1024,0.6743040084838867,0.6736640334129333,0.677068829536438,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:52,0.6.3 +tiled_geglu,huggingface,forward,speed,ms,T,sequence length,2048,1.418239951133728,1.418239951133728,1.421120047569275,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", 
""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:52,0.6.3 +tiled_geglu,huggingface,forward,speed,ms,T,sequence length,4096,2.88972806930542,2.889113664627075,2.8909568786621094,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:52,0.6.3 +tiled_geglu,huggingface,forward,speed,ms,T,sequence length,8192,5.701375961303711,5.701375961303711,5.701375961303711,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:52,0.6.3 +tiled_geglu,huggingface,forward,speed,ms,T,sequence length,16384,11.276288032531738,11.276288032531738,11.276288032531738,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:52,0.6.3 +tiled_geglu,deepspeed_tiled,forward,speed,ms,T,sequence length,1024,0.7433919906616211,0.7423999905586243,0.7444480061531067,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:54,0.6.3 +tiled_geglu,deepspeed_tiled,forward,speed,ms,T,sequence length,2048,1.4137760400772095,1.4131200313568115,1.4152319431304932,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:54,0.6.3 +tiled_geglu,deepspeed_tiled,forward,speed,ms,T,sequence length,4096,2.8241920471191406,2.823500871658325,2.8266496658325195,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:54,0.6.3 +tiled_geglu,deepspeed_tiled,forward,speed,ms,T,sequence length,8192,6.087679862976074,6.087679862976074,6.087679862976074,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:54,0.6.3 +tiled_geglu,deepspeed_tiled,forward,speed,ms,T,sequence length,16384,12.353535652160645,12.353535652160645,12.353535652160645,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:54,0.6.3 +tiled_geglu,huggingface,backward,speed,ms,T,sequence length,1024,1.5499199628829956,1.5489535331726074,1.5523840188980103,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:55,0.6.3 +tiled_geglu,huggingface,backward,speed,ms,T,sequence length,2048,3.171328067779541,3.169484853744507,3.173171281814575,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, 
""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:55,0.6.3 +tiled_geglu,huggingface,backward,speed,ms,T,sequence length,4096,6.263807773590088,6.263807773590088,6.263807773590088,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:55,0.6.3 +tiled_geglu,huggingface,backward,speed,ms,T,sequence length,8192,12.046143531799316,12.046143531799316,12.046143531799316,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:55,0.6.3 +tiled_geglu,huggingface,backward,speed,ms,T,sequence length,16384,23.839744567871094,23.839744567871094,23.839744567871094,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:55,0.6.3 +tiled_geglu,deepspeed_tiled,backward,speed,ms,T,sequence length,1024,2.6757121086120605,2.6755776405334473,2.676710367202759,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:58,0.6.3 +tiled_geglu,deepspeed_tiled,backward,speed,ms,T,sequence length,2048,4.7329277992248535,4.7329277992248535,4.7329277992248535,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:58,0.6.3 +tiled_geglu,deepspeed_tiled,backward,speed,ms,T,sequence length,4096,9.078783988952637,9.078783988952637,9.078783988952637,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:58,0.6.3 +tiled_geglu,deepspeed_tiled,backward,speed,ms,T,sequence length,8192,18.63680076599121,18.63680076599121,18.63680076599121,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:58,0.6.3 +tiled_geglu,deepspeed_tiled,backward,speed,ms,T,sequence length,16384,37.06163024902344,37.06163024902344,37.06163024902344,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:58,0.6.3 +tiled_geglu,huggingface,full,memory,MB,T,sequence length,1024,264.25,264.25,264.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:59,0.6.3 +tiled_geglu,huggingface,full,memory,MB,T,sequence length,2048,400.25,400.25,400.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, 
""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:59,0.6.3 +tiled_geglu,huggingface,full,memory,MB,T,sequence length,4096,688.25,688.25,688.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:59,0.6.3 +tiled_geglu,huggingface,full,memory,MB,T,sequence length,8192,1264.25,1264.25,1264.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:59,0.6.3 +tiled_geglu,huggingface,full,memory,MB,T,sequence length,16384,2416.25,2416.25,2416.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:22:59,0.6.3 +tiled_geglu,deepspeed_tiled,full,memory,MB,T,sequence length,1024,190.25,190.25,190.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:02,0.6.3 +tiled_geglu,deepspeed_tiled,full,memory,MB,T,sequence length,2048,252.25,252.25,252.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:02,0.6.3 +tiled_geglu,deepspeed_tiled,full,memory,MB,T,sequence length,4096,376.25,376.25,376.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:02,0.6.3 +tiled_geglu,deepspeed_tiled,full,memory,MB,T,sequence length,8192,640.25,640.25,640.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:02,0.6.3 +tiled_geglu,deepspeed_tiled,full,memory,MB,T,sequence length,16384,1168.25,1168.25,1168.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:02,0.6.3 +tiled_geglu,huggingface,forward,memory,MB,T,sequence length,1024,144.25,144.25,144.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:02,0.6.3 +tiled_geglu,huggingface,forward,memory,MB,T,sequence length,2048,224.25,224.25,224.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:02,0.6.3 +tiled_geglu,huggingface,forward,memory,MB,T,sequence length,4096,384.25,384.25,384.25,"{""bsz"": 2, ""hidden_size"": 2048, 
""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:02,0.6.3 +tiled_geglu,huggingface,forward,memory,MB,T,sequence length,8192,704.25,704.25,704.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:02,0.6.3 +tiled_geglu,huggingface,forward,memory,MB,T,sequence length,16384,1344.25,1344.25,1344.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:02,0.6.3 +tiled_geglu,deepspeed_tiled,forward,memory,MB,T,sequence length,1024,90.25,90.25,90.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:04,0.6.3 +tiled_geglu,deepspeed_tiled,forward,memory,MB,T,sequence length,2048,116.25,116.25,116.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:04,0.6.3 +tiled_geglu,deepspeed_tiled,forward,memory,MB,T,sequence length,4096,168.25,168.25,168.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:04,0.6.3 +tiled_geglu,deepspeed_tiled,forward,memory,MB,T,sequence length,8192,272.25,272.25,272.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:04,0.6.3 +tiled_geglu,deepspeed_tiled,forward,memory,MB,T,sequence length,16384,480.25,480.25,480.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:04,0.6.3 +tiled_geglu,huggingface,backward,memory,MB,T,sequence length,1024,264.25,264.25,264.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:05,0.6.3 +tiled_geglu,huggingface,backward,memory,MB,T,sequence length,2048,400.25,400.25,400.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:05,0.6.3 +tiled_geglu,huggingface,backward,memory,MB,T,sequence length,4096,688.25,688.25,688.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:05,0.6.3 +tiled_geglu,huggingface,backward,memory,MB,T,sequence 
length,8192,1264.25,1264.25,1264.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:05,0.6.3 +tiled_geglu,huggingface,backward,memory,MB,T,sequence length,16384,2416.25,2416.25,2416.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:05,0.6.3 +tiled_geglu,deepspeed_tiled,backward,memory,MB,T,sequence length,1024,190.25,190.25,190.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:07,0.6.3 +tiled_geglu,deepspeed_tiled,backward,memory,MB,T,sequence length,2048,252.25,252.25,252.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:07,0.6.3 +tiled_geglu,deepspeed_tiled,backward,memory,MB,T,sequence length,4096,376.25,376.25,376.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:07,0.6.3 +tiled_geglu,deepspeed_tiled,backward,memory,MB,T,sequence length,8192,640.25,640.25,640.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:07,0.6.3 +tiled_geglu,deepspeed_tiled,backward,memory,MB,T,sequence length,16384,1168.25,1168.25,1168.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""gelu_pytorch_tanh"", ""activation_type"": ""geglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:07,0.6.3 +tiled_swiglu,huggingface,full,speed,ms,T,sequence length,1024,2.2517759799957275,2.2517759799957275,2.254848003387451,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:08,0.6.3 +tiled_swiglu,huggingface,full,speed,ms,T,sequence length,2048,4.588511943817139,4.587302207946777,4.5897216796875,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:08,0.6.3 +tiled_swiglu,huggingface,full,speed,ms,T,sequence length,4096,9.233407974243164,9.233407974243164,9.233407974243164,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:08,0.6.3 +tiled_swiglu,huggingface,full,speed,ms,T,sequence length,8192,17.869823455810547,17.869823455810547,17.869823455810547,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, 
""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:08,0.6.3 +tiled_swiglu,huggingface,full,speed,ms,T,sequence length,16384,35.34422302246094,35.34422302246094,35.34422302246094,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:08,0.6.3 +tiled_swiglu,deepspeed_tiled,full,speed,ms,T,sequence length,1024,3.4257922172546387,3.424870491027832,3.426713705062866,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:12,0.6.3 +tiled_swiglu,deepspeed_tiled,full,speed,ms,T,sequence length,2048,6.155263900756836,6.155263900756836,6.155263900756836,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:12,0.6.3 +tiled_swiglu,deepspeed_tiled,full,speed,ms,T,sequence length,4096,11.92959976196289,11.92959976196289,11.92959976196289,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:12,0.6.3 +tiled_swiglu,deepspeed_tiled,full,speed,ms,T,sequence length,8192,24.815616607666016,24.815616607666016,24.815616607666016,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:12,0.6.3 +tiled_swiglu,deepspeed_tiled,full,speed,ms,T,sequence length,16384,49.62918472290039,49.62918472290039,49.62918472290039,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:12,0.6.3 +tiled_swiglu,huggingface,forward,speed,ms,T,sequence length,1024,0.6748160123825073,0.6737920045852661,0.6758400201797485,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:13,0.6.3 +tiled_swiglu,huggingface,forward,speed,ms,T,sequence length,2048,1.4332799911499023,1.4325759410858154,1.4335999488830566,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:13,0.6.3 +tiled_swiglu,huggingface,forward,speed,ms,T,sequence length,4096,2.91212797164917,2.904217481613159,2.9146623611450195,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:13,0.6.3 +tiled_swiglu,huggingface,forward,speed,ms,T,sequence length,8192,5.658976078033447,5.658976078033447,5.658976078033447,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 
06:23:13,0.6.3 +tiled_swiglu,huggingface,forward,speed,ms,T,sequence length,16384,11.341952323913574,11.341952323913574,11.341952323913574,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:13,0.6.3 +tiled_swiglu,deepspeed_tiled,forward,speed,ms,T,sequence length,1024,0.7454720139503479,0.7429631948471069,0.7456768155097961,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:15,0.6.3 +tiled_swiglu,deepspeed_tiled,forward,speed,ms,T,sequence length,2048,1.4120960235595703,1.410048007965088,1.4120960235595703,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:15,0.6.3 +tiled_swiglu,deepspeed_tiled,forward,speed,ms,T,sequence length,4096,2.825216054916382,2.825216054916382,2.8264448642730713,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:15,0.6.3 +tiled_swiglu,deepspeed_tiled,forward,speed,ms,T,sequence length,8192,6.077439785003662,6.077439785003662,6.077439785003662,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:15,0.6.3 +tiled_swiglu,deepspeed_tiled,forward,speed,ms,T,sequence length,16384,12.356608390808105,12.356608390808105,12.356608390808105,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:15,0.6.3 +tiled_swiglu,huggingface,backward,speed,ms,T,sequence length,1024,1.551360011100769,1.5511807203292847,1.5532032251358032,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:16,0.6.3 +tiled_swiglu,huggingface,backward,speed,ms,T,sequence length,2048,3.1928319931030273,3.1885311603546143,3.1971328258514404,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:16,0.6.3 +tiled_swiglu,huggingface,backward,speed,ms,T,sequence length,4096,6.273248195648193,6.273248195648193,6.273248195648193,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:16,0.6.3 +tiled_swiglu,huggingface,backward,speed,ms,T,sequence length,8192,12.058752059936523,12.058752059936523,12.058752059936523,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:16,0.6.3 
+tiled_swiglu,huggingface,backward,speed,ms,T,sequence length,16384,23.853055953979492,23.853055953979492,23.853055953979492,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:16,0.6.3 +tiled_swiglu,deepspeed_tiled,backward,speed,ms,T,sequence length,1024,2.6746881008148193,2.6728639602661133,2.6789886951446533,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:20,0.6.3 +tiled_swiglu,deepspeed_tiled,backward,speed,ms,T,sequence length,2048,4.739071846008301,4.739071846008301,4.739071846008301,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:20,0.6.3 +tiled_swiglu,deepspeed_tiled,backward,speed,ms,T,sequence length,4096,9.084927558898926,9.084927558898926,9.084927558898926,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:20,0.6.3 +tiled_swiglu,deepspeed_tiled,backward,speed,ms,T,sequence length,8192,18.729759216308594,18.729759216308594,18.729759216308594,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:20,0.6.3 +tiled_swiglu,deepspeed_tiled,backward,speed,ms,T,sequence length,16384,37.13724899291992,37.13724899291992,37.13724899291992,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:20,0.6.3 +tiled_swiglu,huggingface,full,memory,MB,T,sequence length,1024,264.25,264.25,264.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:20,0.6.3 +tiled_swiglu,huggingface,full,memory,MB,T,sequence length,2048,400.25,400.25,400.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:20,0.6.3 +tiled_swiglu,huggingface,full,memory,MB,T,sequence length,4096,688.25,688.25,688.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:20,0.6.3 +tiled_swiglu,huggingface,full,memory,MB,T,sequence length,8192,1264.25,1264.25,1264.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:20,0.6.3 +tiled_swiglu,huggingface,full,memory,MB,T,sequence length,16384,2416.25,2416.25,2416.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", 
""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:20,0.6.3 +tiled_swiglu,deepspeed_tiled,full,memory,MB,T,sequence length,1024,190.25,190.25,190.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:23,0.6.3 +tiled_swiglu,deepspeed_tiled,full,memory,MB,T,sequence length,2048,252.25,252.25,252.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:23,0.6.3 +tiled_swiglu,deepspeed_tiled,full,memory,MB,T,sequence length,4096,376.25,376.25,376.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:23,0.6.3 +tiled_swiglu,deepspeed_tiled,full,memory,MB,T,sequence length,8192,640.25,640.25,640.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:23,0.6.3 +tiled_swiglu,deepspeed_tiled,full,memory,MB,T,sequence length,16384,1168.25,1168.25,1168.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:23,0.6.3 +tiled_swiglu,huggingface,forward,memory,MB,T,sequence length,1024,144.25,144.25,144.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:24,0.6.3 +tiled_swiglu,huggingface,forward,memory,MB,T,sequence length,2048,224.25,224.25,224.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:24,0.6.3 +tiled_swiglu,huggingface,forward,memory,MB,T,sequence length,4096,384.25,384.25,384.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:24,0.6.3 +tiled_swiglu,huggingface,forward,memory,MB,T,sequence length,8192,704.25,704.25,704.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:24,0.6.3 +tiled_swiglu,huggingface,forward,memory,MB,T,sequence length,16384,1344.25,1344.25,1344.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:24,0.6.3 +tiled_swiglu,deepspeed_tiled,forward,memory,MB,T,sequence length,1024,90.25,90.25,90.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": 
""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:26,0.6.3 +tiled_swiglu,deepspeed_tiled,forward,memory,MB,T,sequence length,2048,116.25,116.25,116.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:26,0.6.3 +tiled_swiglu,deepspeed_tiled,forward,memory,MB,T,sequence length,4096,168.25,168.25,168.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:26,0.6.3 +tiled_swiglu,deepspeed_tiled,forward,memory,MB,T,sequence length,8192,272.25,272.25,272.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:26,0.6.3 +tiled_swiglu,deepspeed_tiled,forward,memory,MB,T,sequence length,16384,480.25,480.25,480.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:26,0.6.3 +tiled_swiglu,huggingface,backward,memory,MB,T,sequence length,1024,264.25,264.25,264.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:26,0.6.3 +tiled_swiglu,huggingface,backward,memory,MB,T,sequence length,2048,400.25,400.25,400.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:26,0.6.3 +tiled_swiglu,huggingface,backward,memory,MB,T,sequence length,4096,688.25,688.25,688.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:26,0.6.3 +tiled_swiglu,huggingface,backward,memory,MB,T,sequence length,8192,1264.25,1264.25,1264.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:26,0.6.3 +tiled_swiglu,huggingface,backward,memory,MB,T,sequence length,16384,2416.25,2416.25,2416.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:26,0.6.3 +tiled_swiglu,deepspeed_tiled,backward,memory,MB,T,sequence length,1024,190.25,190.25,190.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:29,0.6.3 +tiled_swiglu,deepspeed_tiled,backward,memory,MB,T,sequence length,2048,252.25,252.25,252.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 
06:23:29,0.6.3 +tiled_swiglu,deepspeed_tiled,backward,memory,MB,T,sequence length,4096,376.25,376.25,376.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:29,0.6.3 +tiled_swiglu,deepspeed_tiled,backward,memory,MB,T,sequence length,8192,640.25,640.25,640.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:29,0.6.3 +tiled_swiglu,deepspeed_tiled,backward,memory,MB,T,sequence length,16384,1168.25,1168.25,1168.25,"{""bsz"": 2, ""hidden_size"": 2048, ""intermediate_size"": 4096, ""hidden_act"": ""silu"", ""activation_type"": ""swiglu"", ""num_shards"": 4, ""dtype"": ""torch.bfloat16""}",NVIDIA GeForce RTX 4090,2025-11-11 06:23:29,0.6.3 +tvd,liger,full,memory,MB,V,vocab size,4096,1792.0009765625,1792.0009765625,1792.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:48,0.7.0 +tvd,liger,full,memory,MB,V,vocab size,8192,3584.0009765625,3584.0009765625,3584.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:48,0.7.0 +tvd,liger,full,memory,MB,V,vocab size,16384,7168.0009765625,7168.0009765625,7168.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:48,0.7.0 +tvd,liger,full,memory,MB,V,vocab size,32768,14336.0009765625,14336.0009765625,14336.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:48,0.7.0 +tvd,liger,full,memory,MB,V,vocab size,65536,28672.0,28672.0,28672.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:48,0.7.0 +tvd,liger,full,memory,MB,V,vocab size,131072,57344.0,57344.0,57344.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:48,0.7.0 +tvd,torch,full,memory,MB,V,vocab size,4096,2048.0009765625,2048.0009765625,2048.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:48,0.7.0 +tvd,torch,full,memory,MB,V,vocab size,8192,4096.0009765625,4096.0009765625,4096.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:48,0.7.0 +tvd,torch,full,memory,MB,V,vocab size,16384,8192.0009765625,8192.0009765625,8192.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:48,0.7.0 +tvd,torch,full,memory,MB,V,vocab size,32768,16384.0,16384.0,16384.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:48,0.7.0 +tvd,torch,full,memory,MB,V,vocab size,65536,32768.0,32768.0,32768.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:48,0.7.0 +tvd,torch,full,memory,MB,V,vocab size,131072,65536.0,65536.0,65536.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:48,0.7.0 +tvd,liger,forward,speed,ms,V,vocab size,4096,0.2757120132446289,0.27487359642982484,0.27616640329360964,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:50,0.7.0 +tvd,liger,forward,speed,ms,V,vocab size,8192,0.5338559746742249,0.5333759784698486,0.5346879959106445,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:50,0.7.0 +tvd,liger,forward,speed,ms,V,vocab size,16384,1.0511679649353027,1.0505280494689941,1.0521472215652465,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:50,0.7.0 +tvd,liger,forward,speed,ms,V,vocab size,32768,2.0986878871917725,2.09736967086792,2.0999168872833254,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB 
HBM3,2026-03-03 23:02:50,0.7.0 +tvd,liger,forward,speed,ms,V,vocab size,65536,4.221951961517334,4.22039680480957,4.222847938537598,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:50,0.7.0 +tvd,liger,forward,speed,ms,V,vocab size,131072,8.501215934753418,8.498592376708984,8.50380802154541,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:50,0.7.0 +tvd,torch,forward,speed,ms,V,vocab size,4096,0.7288320064544678,0.727942419052124,0.7296640276908875,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:51,0.7.0 +tvd,torch,forward,speed,ms,V,vocab size,8192,1.4264639616012573,1.42576003074646,1.4272960424423218,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:51,0.7.0 +tvd,torch,forward,speed,ms,V,vocab size,16384,2.81440007686615,2.8132031917572022,2.815097618103027,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:51,0.7.0 +tvd,torch,forward,speed,ms,V,vocab size,32768,5.5965118408203125,5.59548807144165,5.598131275177002,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:51,0.7.0 +tvd,torch,forward,speed,ms,V,vocab size,65536,11.178752422332764,11.176428604125977,11.180454635620118,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:51,0.7.0 +tvd,torch,forward,speed,ms,V,vocab size,131072,22.33670425415039,22.334880065917968,22.339027404785156,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:51,0.7.0 +tvd,liger,full,speed,ms,V,vocab size,4096,1.123952031135559,1.1221888303756713,1.1291328191757202,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:53,0.7.0 +tvd,liger,full,speed,ms,V,vocab size,8192,2.1660319566726685,2.162835216522217,2.169088077545166,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:53,0.7.0 +tvd,liger,full,speed,ms,V,vocab size,16384,4.563424110412598,4.559807777404785,4.5669121742248535,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:53,0.7.0 +tvd,liger,full,speed,ms,V,vocab size,32768,9.092079639434814,9.089529991149902,9.094182014465332,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:53,0.7.0 +tvd,liger,full,speed,ms,V,vocab size,65536,18.217248916625977,18.20675277709961,18.219014739990236,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:53,0.7.0 +tvd,liger,full,speed,ms,V,vocab size,131072,36.477935791015625,36.46965026855469,36.48622131347656,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:53,0.7.0 +tvd,torch,full,speed,ms,V,vocab size,4096,2.1256959438323975,2.1249279975891113,2.1270463466644287,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:54,0.7.0 +tvd,torch,full,speed,ms,V,vocab size,8192,4.191232204437256,4.189510250091553,4.192793464660644,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:54,0.7.0 +tvd,torch,full,speed,ms,V,vocab size,16384,8.638431549072266,8.636992454528809,8.639007568359375,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:54,0.7.0 +tvd,torch,full,speed,ms,V,vocab size,32768,17.25654411315918,17.25450286865234,17.25882225036621,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:54,0.7.0 +tvd,torch,full,speed,ms,V,vocab size,65536,34.54822540283203,34.546746826171876,34.549703979492186,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:54,0.7.0 +tvd,torch,full,speed,ms,V,vocab size,131072,69.17910766601562,69.17910766601562,69.17910766601562,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:54,0.7.0 +tvd,liger,backward,speed,ms,V,vocab 
size,4096,0.8502079844474792,0.8484799861907959,0.8526080250740051,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:55,0.7.0 +tvd,liger,backward,speed,ms,V,vocab size,8192,1.6321280002593994,1.629702377319336,1.6350399732589722,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:55,0.7.0 +tvd,liger,backward,speed,ms,V,vocab size,16384,3.5109760761260986,3.5084415912628173,3.513107109069824,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:55,0.7.0 +tvd,liger,backward,speed,ms,V,vocab size,32768,6.989071846008301,6.985472011566161,6.994240188598633,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:55,0.7.0 +tvd,liger,backward,speed,ms,V,vocab size,65536,13.969247817993164,13.95904598236084,13.971328163146971,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:55,0.7.0 +tvd,liger,backward,speed,ms,V,vocab size,131072,27.982528686523438,27.963673400878903,27.987577819824217,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:55,0.7.0 +tvd,torch,backward,speed,ms,V,vocab size,4096,1.398911952972412,1.3979583740234376,1.4000320434570312,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:57,0.7.0 +tvd,torch,backward,speed,ms,V,vocab size,8192,2.7701759338378906,2.7694976329803467,2.7718528747558593,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:57,0.7.0 +tvd,torch,backward,speed,ms,V,vocab size,16384,5.828160047531128,5.8249921798706055,5.829792022705078,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:57,0.7.0 +tvd,torch,backward,speed,ms,V,vocab size,32768,11.665760040283203,11.664883232116699,11.666317176818847,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:57,0.7.0 +tvd,torch,backward,speed,ms,V,vocab size,65536,23.379840850830078,23.37938575744629,23.381267929077147,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:57,0.7.0 +tvd,torch,backward,speed,ms,V,vocab size,131072,46.83844757080078,46.8328125,46.84408264160156,"{""B"": 8, ""T"": 2048}",NVIDIA H100 80GB HBM3,2026-03-03 23:02:57,0.7.0 +group_norm,liger,forward,speed,ms,C,num_channels,32,0.017535999417304993,0.016863999888300896,0.01833599992096424,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:15,0.7.0 +group_norm,liger,forward,speed,ms,C,num_channels,64,0.018848000094294548,0.018015999346971512,0.019487999379634857,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:15,0.7.0 +group_norm,liger,forward,speed,ms,C,num_channels,128,0.026623999699950218,0.024607999250292778,0.026688000187277794,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:15,0.7.0 +group_norm,liger,forward,speed,ms,C,num_channels,256,0.038943998515605927,0.03888000175356865,0.03903999924659729,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:15,0.7.0 +group_norm,liger,forward,speed,ms,C,num_channels,512,0.06351999938488007,0.06345599889755249,0.06550399959087372,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:15,0.7.0 +group_norm,liger,forward,speed,ms,C,num_channels,1024,0.11475200206041336,0.11468800157308578,0.11673600226640701,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", 
""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:15,0.7.0 +group_norm,liger,forward,speed,ms,C,num_channels,2048,0.21910400688648224,0.217056006193161,0.22115199267864227,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:15,0.7.0 +group_norm,huggingface,forward,speed,ms,C,num_channels,32,0.030688000842928886,0.030592000111937523,0.030751999467611313,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:19,0.7.0 +group_norm,huggingface,forward,speed,ms,C,num_channels,64,0.043007999658584595,0.04294399917125702,0.04303999990224838,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:19,0.7.0 +group_norm,huggingface,forward,speed,ms,C,num_channels,128,0.07168000191450119,0.07161600142717361,0.07174400240182877,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:19,0.7.0 +group_norm,huggingface,forward,speed,ms,C,num_channels,256,0.13516800105571747,0.1351040005683899,0.13523200154304504,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:19,0.7.0 +group_norm,huggingface,forward,speed,ms,C,num_channels,512,0.25808000564575195,0.2580159902572632,0.25900799036026,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:19,0.7.0 +group_norm,huggingface,forward,speed,ms,C,num_channels,1024,0.4986239969730377,0.4976640045642853,0.4997439980506897,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:19,0.7.0 +group_norm,huggingface,forward,speed,ms,C,num_channels,2048,0.9819360077381134,0.9800639748573303,0.9830080270767212,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:19,0.7.0 +group_norm,liger,full,speed,ms,C,num_channels,32,0.1658720001578331,0.16368000209331512,0.16958080232143402,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:22,0.7.0 +group_norm,liger,full,speed,ms,C,num_channels,64,0.1730239987373352,0.17123199999332428,0.17520000040531158,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:22,0.7.0 +group_norm,liger,full,speed,ms,C,num_channels,128,0.1695999950170517,0.16783360242843628,0.1717183977365494,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:22,0.7.0 +group_norm,liger,full,speed,ms,C,num_channels,256,0.174112007021904,0.17206400632858276,0.17718400061130524,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:22,0.7.0 +group_norm,liger,full,speed,ms,C,num_channels,512,0.18745599687099457,0.18636800348758698,0.18848000466823578,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:22,0.7.0 +group_norm,liger,full,speed,ms,C,num_channels,1024,0.3388479948043823,0.33792001008987427,0.3400000035762787,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 
1e-06}",NVIDIA B200,2026-02-28 00:23:22,0.7.0 +group_norm,liger,full,speed,ms,C,num_channels,2048,0.6390079855918884,0.6371200084686279,0.6410560011863708,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:22,0.7.0 +group_norm,huggingface,full,speed,ms,C,num_channels,32,0.08396799862384796,0.08390399813652039,0.08403199911117554,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:26,0.7.0 +group_norm,huggingface,full,speed,ms,C,num_channels,64,0.11267200112342834,0.11260800063610077,0.1128000020980835,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:26,0.7.0 +group_norm,huggingface,full,speed,ms,C,num_channels,128,0.20054399967193604,0.19868800044059753,0.20080000162124634,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:26,0.7.0 +group_norm,huggingface,full,speed,ms,C,num_channels,256,0.35020801424980164,0.34828799962997437,0.3511039912700653,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:26,0.7.0 +group_norm,huggingface,full,speed,ms,C,num_channels,512,0.6307839751243591,0.6297919750213623,0.6309120059013367,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:26,0.7.0 +group_norm,huggingface,full,speed,ms,C,num_channels,1024,1.177664041519165,1.1766079664230347,1.1796480417251587,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:26,0.7.0 +group_norm,huggingface,full,speed,ms,C,num_channels,2048,2.2947518825531006,2.292736053466797,2.296736001968384,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:26,0.7.0 +group_norm,liger,backward,speed,ms,C,num_channels,32,0.06643199920654297,0.0655359998345375,0.06752000004053116,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:28,0.7.0 +group_norm,liger,backward,speed,ms,C,num_channels,64,0.06732799857854843,0.0663679987192154,0.06838399916887283,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:28,0.7.0 +group_norm,liger,backward,speed,ms,C,num_channels,128,0.07171200215816498,0.06969600170850754,0.07273600250482559,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:28,0.7.0 +group_norm,liger,backward,speed,ms,C,num_channels,256,0.07580800354480743,0.07571200281381607,0.07683199644088745,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:28,0.7.0 +group_norm,liger,backward,speed,ms,C,num_channels,512,0.12697599828243256,0.1249919980764389,0.12703999876976013,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:28,0.7.0 +group_norm,liger,backward,speed,ms,C,num_channels,1024,0.2253440022468567,0.2252800017595291,0.22729599475860596,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA 
B200,2026-02-28 00:23:28,0.7.0 +group_norm,liger,backward,speed,ms,C,num_channels,2048,0.42585599422454834,0.42396798729896545,0.4260160028934479,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:28,0.7.0 +group_norm,huggingface,backward,speed,ms,C,num_channels,32,0.05532800033688545,0.05526399984955788,0.056352000683546066,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:31,0.7.0 +group_norm,huggingface,backward,speed,ms,C,num_channels,64,0.07372800260782242,0.07171200215816498,0.0739263966679573,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:31,0.7.0 +group_norm,huggingface,backward,speed,ms,C,num_channels,128,0.13315199315547943,0.13308799266815186,0.13331200182437897,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:31,0.7.0 +group_norm,huggingface,backward,speed,ms,C,num_channels,256,0.21916800737380981,0.21904000639915466,0.21926400065422058,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:31,0.7.0 +group_norm,huggingface,backward,speed,ms,C,num_channels,512,0.374783992767334,0.37379199266433716,0.37484800815582275,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:31,0.7.0 +group_norm,huggingface,backward,speed,ms,C,num_channels,1024,0.6820799708366394,0.6810240149497986,0.6839039921760559,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:31,0.7.0 +group_norm,huggingface,backward,speed,ms,C,num_channels,2048,1.3158719539642334,1.3157440423965454,1.3177599906921387,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:31,0.7.0 +group_norm,liger,full,memory,MB,C,num_channels,32,40.01171875,40.01171875,40.01171875,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:31,0.7.0 +group_norm,liger,full,memory,MB,C,num_channels,64,80.01953125,80.01953125,80.01953125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:31,0.7.0 +group_norm,liger,full,memory,MB,C,num_channels,128,160.03515625,160.03515625,160.03515625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:31,0.7.0 +group_norm,liger,full,memory,MB,C,num_channels,256,320.0703125,320.0703125,320.0703125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:31,0.7.0 +group_norm,liger,full,memory,MB,C,num_channels,512,640.140625,640.140625,640.140625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:31,0.7.0 +group_norm,liger,full,memory,MB,C,num_channels,1024,1280.28125,1280.28125,1280.28125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:31,0.7.0 +group_norm,liger,full,memory,MB,C,num_channels,2048,2560.5625,2560.5625,2560.5625,"{""M"": 128, ""H"": 512, 
""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:31,0.7.0 +group_norm,huggingface,full,memory,MB,C,num_channels,32,40.06640625,40.06640625,40.06640625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,full,memory,MB,C,num_channels,64,80.12890625,80.12890625,80.12890625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,full,memory,MB,C,num_channels,128,160.25390625,160.25390625,160.25390625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,full,memory,MB,C,num_channels,256,320.5078125,320.5078125,320.5078125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,full,memory,MB,C,num_channels,512,641.015625,641.015625,641.015625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,full,memory,MB,C,num_channels,1024,1282.03125,1282.03125,1282.03125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,full,memory,MB,C,num_channels,2048,2564.0625,2564.0625,2564.0625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,liger,forward,memory,MB,C,num_channels,32,40.01171875,40.01171875,40.01171875,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,liger,forward,memory,MB,C,num_channels,64,80.01953125,80.01953125,80.01953125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,liger,forward,memory,MB,C,num_channels,128,160.03515625,160.03515625,160.03515625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,liger,forward,memory,MB,C,num_channels,256,320.0703125,320.0703125,320.0703125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,liger,forward,memory,MB,C,num_channels,512,640.140625,640.140625,640.140625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,liger,forward,memory,MB,C,num_channels,1024,1280.28125,1280.28125,1280.28125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,liger,forward,memory,MB,C,num_channels,2048,2560.5625,2560.5625,2560.5625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,forward,memory,MB,C,num_channels,32,40.06640625,40.06640625,40.06640625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 
+group_norm,huggingface,forward,memory,MB,C,num_channels,64,80.12890625,80.12890625,80.12890625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,forward,memory,MB,C,num_channels,128,160.25390625,160.25390625,160.25390625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,forward,memory,MB,C,num_channels,256,320.5078125,320.5078125,320.5078125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,forward,memory,MB,C,num_channels,512,641.015625,641.015625,641.015625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,forward,memory,MB,C,num_channels,1024,1282.03125,1282.03125,1282.03125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,forward,memory,MB,C,num_channels,2048,2564.0625,2564.0625,2564.0625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,liger,backward,memory,MB,C,num_channels,32,40.01171875,40.01171875,40.01171875,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,liger,backward,memory,MB,C,num_channels,64,80.01953125,80.01953125,80.01953125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,liger,backward,memory,MB,C,num_channels,128,160.03515625,160.03515625,160.03515625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,liger,backward,memory,MB,C,num_channels,256,320.0703125,320.0703125,320.0703125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,liger,backward,memory,MB,C,num_channels,512,640.140625,640.140625,640.140625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,liger,backward,memory,MB,C,num_channels,1024,1280.28125,1280.28125,1280.28125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,liger,backward,memory,MB,C,num_channels,2048,2560.5625,2560.5625,2560.5625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,backward,memory,MB,C,num_channels,32,40.06640625,40.06640625,40.06640625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,backward,memory,MB,C,num_channels,64,80.12890625,80.12890625,80.12890625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 
+group_norm,huggingface,backward,memory,MB,C,num_channels,128,160.25390625,160.25390625,160.25390625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,backward,memory,MB,C,num_channels,256,320.5078125,320.5078125,320.5078125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,backward,memory,MB,C,num_channels,512,641.015625,641.015625,641.015625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,backward,memory,MB,C,num_channels,1024,1282.03125,1282.03125,1282.03125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 +group_norm,huggingface,backward,memory,MB,C,num_channels,2048,2564.0625,2564.0625,2564.0625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA B200,2026-02-28 00:23:32,0.7.0 diff --git a/benchmark/scripts/__init__.py b/benchmark/scripts/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/benchmark/scripts/benchmark_cpo_loss.py b/benchmark/scripts/benchmark_cpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..8b10d518880b19644bd7d6c3cc4b9cd64cc8a541 --- /dev/null +++ b/benchmark/scripts/benchmark_cpo_loss.py @@ -0,0 +1,167 @@ +import os +import sys + +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.utils import infer_device + +device = infer_device() + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +############################################################################# +# Test the memory consumption of the fused linear CPO loss +############################################################################# + + +def bench_memory_fused_linear_cpo_loss( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO + from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO + + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + + # Instantiate once and retrieve the first output only + torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device) + liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device) + torch_fwd = lambda x, target: torch_lm_head_cpo(x, target)[0] + liger_fwd = lambda x, target: liger_lm_head_cpo(x, target)[0] + + _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) + target = torch.randint(V, (B, T), dtype=torch.long, device=device) + + def fwd(): + if provider == "liger": + return liger_fwd(_input, target) + elif provider == "huggingface": + return torch_fwd(_input, target) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + )
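+ + +# NOTE: QUANTILES presumably holds (0.5, 0.2, 0.8), matching the (mem_50, mem_20, mem_80) +# unpacking above; each benchmark therefore reports a median plus 20th/80th percentile +# bounds over the 10 measured iterations.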
+ + +# ############################################################################# +# # Test the speed of the fused linear CPO loss +# ############################################################################# + + +def bench_speed_fused_linear_cpo_loss( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO + from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO + + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + # Instantiate once and retrieve the first output only + torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device) + liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device) + torch_fwd = lambda x, target: torch_lm_head_cpo(x, target)[0] + liger_fwd = lambda x, target: liger_lm_head_cpo(x, target)[0] + + _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) + target = torch.randint(V, (B, T), dtype=torch.long, device=device) + + def fwd(): + if provider == "liger": + return liger_fwd(_input, target) + elif provider == "huggingface": + return torch_fwd(_input, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[_input], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "fused_linear_cpo_loss", + "x_name": "B", + "x_label": "B", + "x_values": [2**i for i in range(1, 5)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "T": 1024, + "H": 4096, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_cpo_loss, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_cpo_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_cross_entropy.py b/benchmark/scripts/benchmark_cross_entropy.py new file mode 100755 index 0000000000000000000000000000000000000000..cdd61814ac076923a4d75b2eeef0866c8d70f081 --- /dev/null +++ b/benchmark/scripts/benchmark_cross_entropy.py @@ -0,0 +1,126 @@ +import torch +import triton + +from torch.nn import CrossEntropyLoss +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss +from liger_kernel.utils import infer_device + +device = infer_device() + + +def bench_memory_cross_entropy( + input:
SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + torch_ce = CrossEntropyLoss() + liger_ce = LigerCrossEntropyLoss() + + V = input.x + provider = input.kernel_provider + B = input.extra_benchmark_config["B"] + T = input.extra_benchmark_config["T"] + + _input = torch.randn(B * T, V, requires_grad=True, device=device) + target = torch.randint(V, (B * T, 1), device=device).squeeze(1) + + def fwd(): + if provider == "liger": + return liger_ce(_input, target) + else: + return torch_ce(_input, target) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +def bench_speed_cross_entropy( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + torch_ce = CrossEntropyLoss() + liger_ce = LigerCrossEntropyLoss() + + V = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + B = input.extra_benchmark_config["B"] + T = input.extra_benchmark_config["T"] + + _input = torch.randn(B * T, V, requires_grad=True, device=device) + target = torch.randint(V, (B * T, 1), device=device).squeeze(1) + + def fwd(): + if provider == "liger": + return liger_ce(_input, target) + else: + return torch_ce(_input, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) + elif mode == "no-grad-forward": + with torch.no_grad(): + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) + elif mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[_input], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "cross_entropy", + "x_name": "V", + "x_label": "vocab size", + "x_values": [2**i for i in range(12, 18)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [{"B": 8, "T": 2048}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_cross_entropy, + kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_cross_entropy, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_distill_cosine_loss.py b/benchmark/scripts/benchmark_distill_cosine_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..5cf12b495c835d507c5f11c7f8078b2e27414354 --- /dev/null +++ b/benchmark/scripts/benchmark_distill_cosine_loss.py @@ -0,0 +1,266 @@ +import os +import sys + +import torch +import torch.nn as nn +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction +from liger_kernel.utils import infer_device + +device = infer_device() + +# Ensure the project root is in the path 
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +class TorchCosineSimilarityLoss(nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + bias: bool = False, + ): + from test.chunked_loss.test_cosine_loss import HFCosineLoss + + super().__init__() + self.student_lin = nn.Linear(in_features=H // 2, out_features=V, bias=bias).to(dtype=dtype) + self.teacher_lin = nn.Linear(in_features=H, out_features=V, bias=bias).to(dtype=dtype) + self.cosine_loss = HFCosineLoss( + ignore_index=ignore_index, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + temperature=temperature, + ).get_batch_loss_metrics + + def forward(self, student: torch.Tensor, teacher: torch.Tensor, target: torch.Tensor): + return self.cosine_loss(student, self.student_lin.weight, teacher, self.teacher_lin.weight, target) + + +class LigerCosineSimilarityLoss(nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + bias: bool = False, + ): + super().__init__() + self.student_lin = nn.Linear(in_features=H // 2, out_features=V, bias=bias).to(dtype=dtype) + self.teacher_lin = nn.Linear(in_features=H, out_features=V, bias=bias).to(dtype=dtype) + self.weight_hard_loss = weight_hard_loss + self.weight_soft_loss = weight_soft_loss + self.ignore_index = ignore_index + self.temperature = temperature + self.cosine_loss = LigerFusedLinearCosineSimilarityFunction.apply + + def forward(self, student: torch.Tensor, teacher: torch.Tensor, target: torch.Tensor): + return self.cosine_loss( + student, + self.student_lin.weight, + teacher, + self.teacher_lin.weight, + target, + self.student_lin.bias, + self.teacher_lin.bias, + self.weight_hard_loss, + self.weight_soft_loss, + ) + + +def bench_memory_cosine_similarity_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"] + weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + + torch_cosine_loss = TorchCosineSimilarityLoss( + H=H, + V=V, + dtype=dtype, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + bias=bias, + ).to(device) + liger_cosine_loss = LigerCosineSimilarityLoss( + H=H, + V=V, + dtype=dtype, + ignore_index=ignore_index, + bias=bias, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + ).to(device) + + _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype) + student_input1 = _tensor.detach().clone().requires_grad_(True) + student_input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(BT, H, device=device, dtype=dtype) + + target = torch.randint(0, V, (BT,), device=device, dtype=torch.long) + + def fwd(): + if provider == "liger": + return liger_cosine_loss(student_input1, teacher_input, target) + elif provider == "torch": + return torch_cosine_loss(student_input2, teacher_input, target) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, 
_iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +def bench_speed_cosine_similarity_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"] + weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + torch_cosine_loss = TorchCosineSimilarityLoss( + H=H, + V=V, + dtype=dtype, + ignore_index=ignore_index, + bias=bias, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + ).to(device) + + liger_cosine_loss = LigerCosineSimilarityLoss( + H=H, + V=V, + dtype=dtype, + ignore_index=ignore_index, + bias=bias, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + ).to(device) + + _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype) + student_input1 = _tensor.detach().clone().requires_grad_(True) + student_input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(BT, H, device=device, dtype=dtype) + + target = torch.randint(0, V, (BT,), device=device, dtype=torch.long) + + def fwd(): + if provider == "liger": + return liger_cosine_loss(student_input1, teacher_input, target) + elif provider == "torch": + return torch_cosine_loss(student_input2, teacher_input, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[student_input1, student_input2], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "distill_cosine_loss", + "x_name": "BT", + "x_label": "B x T", + "x_values": [2**i for i in range(10, 14)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "H": 4096, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + "bias": False, + "weight_hard_loss": 0.5, + "weight_soft_loss": 0.5, + "ignore_index": -100, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_cosine_similarity_loss, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + + run_benchmarks( + bench_test_fn=bench_memory_cosine_similarity_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_distill_jsd_loss.py b/benchmark/scripts/benchmark_distill_jsd_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..324418e17aea8816846a0fd59330828add136fa2 --- /dev/null +++ b/benchmark/scripts/benchmark_distill_jsd_loss.py @@ -0,0 +1,272 @@
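+# The distillation loss benchmarked below mixes a hard cross-entropy term with a soft +# JSD term via weight_hard_loss / weight_soft_loss; the student head projects H // 2 +# features while the teacher head projects H, i.e. the student is modeled at half the +# teacher's hidden size. A commented reference sketch of the soft term, assuming the +# standard beta-interpolated JSD (illustrative only, not used by the benchmark): +# +# def naive_beta_jsd(student_logits, teacher_logits, beta=0.5): +# p = torch.softmax(teacher_logits, dim=-1) +# q = torch.softmax(student_logits, dim=-1) +# m = beta * p + (1 - beta) * q +# log_m = m.clamp_min(1e-20).log() +# kl_pm = (p * (p.clamp_min(1e-20).log() - log_m)).sum(-1) +# kl_qm = (q * (q.clamp_min(1e-20).log() - log_m)).sum(-1) +# return (beta * kl_pm + (1 - beta) * kl_qm).mean() +import os +import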
sys + +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction +from liger_kernel.utils import get_total_gpu_memory +from liger_kernel.utils import infer_device + +device = infer_device() + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +class TorchJSDLoss(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + bias: bool = False, + ): + from test.chunked_loss.test_jsd_loss import HFJSDLoss + + super().__init__() + self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype) + self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.jsd_loss = HFJSDLoss( + ignore_index=ignore_index, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + temperature=temperature, + ).get_batch_loss_metrics + + def forward(self, student, teacher, target): + return self.jsd_loss( + student, + self.student_lin.weight, + teacher, + self.teacher_lin.weight, + target, + ) + + +class LigerJSDLoss(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + bias: bool = False, + ): + super().__init__() + self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype) + self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.weight_hard_loss = weight_hard_loss + self.weight_soft_loss = weight_soft_loss + self.ignore_index = ignore_index + self.temperature = temperature + self.jsd_loss = LigerFusedLinearJSDFunction.apply + + def forward(self, student, teacher, target): + return self.jsd_loss( + student, + self.student_lin.weight, + teacher, + self.teacher_lin.weight, + target, + self.student_lin.bias, + self.teacher_lin.bias, + self.weight_hard_loss, + self.weight_soft_loss, + ) + + +def bench_memory_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"] + weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + + torch_jsd_loss = TorchJSDLoss( + H=H, + V=V, + dtype=dtype, + ignore_index=ignore_index, + bias=bias, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + ).to(device) + liger_jsd_loss = LigerJSDLoss( + H=H, + V=V, + dtype=dtype, + ignore_index=ignore_index, + bias=bias, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + ).to(device) + + _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype) + student_input1 = _tensor.detach().clone().requires_grad_(True) + student_input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(BT, H, device=device, dtype=dtype) + + target = 
torch.randint(0, V, (BT,), device=device, dtype=torch.long) + + def fwd(): + if provider == "liger": + return liger_jsd_loss(student_input1, teacher_input, target) + elif provider == "torch": + return torch_jsd_loss(student_input2, teacher_input, target) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +def bench_speed_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"] + weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + torch_jsd_loss = TorchJSDLoss( + H=H, + V=V, + dtype=dtype, + ignore_index=ignore_index, + bias=bias, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + ).to(device) + liger_jsd_loss = LigerJSDLoss( + H=H, + V=V, + dtype=dtype, + ignore_index=ignore_index, + bias=bias, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + ).to(device) + + _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype) + student_input1 = _tensor.detach().clone().requires_grad_(True) + student_input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(BT, H, device=device, dtype=dtype) + + target = torch.randint(0, V, (BT,), device=device, dtype=torch.long) + + def fwd(): + if provider == "liger": + return liger_jsd_loss(student_input1, teacher_input, target) + elif provider == "torch": + return torch_jsd_loss(student_input2, teacher_input, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[student_input1, student_input2], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + gpu_memory_gbs = get_total_gpu_memory() + # We know that the full test will require 69GBs for vocab size 2^13 and 39GBs for vocab size 2^12 on torch + if gpu_memory_gbs >= 69: + x_max = 13 + elif gpu_memory_gbs >= 39: + x_max = 12 + else: + x_max = 11 + + common_configs = { + "kernel_name": "distill_jsd_loss", + "x_name": "BT", + "x_label": "B x T", + "x_values": [2**i for i in range(10, x_max + 1)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "H": 4096, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + "bias": False, + "weight_hard_loss": 0.5, + "weight_soft_loss": 0.5, + "ignore_index": -100, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_jsd_loss, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + + run_benchmarks( + bench_test_fn=bench_memory_jsd_loss, + 
kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_dpo_loss.py b/benchmark/scripts/benchmark_dpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..228a228d55753042b0d1bf9471085076e1eefe3b --- /dev/null +++ b/benchmark/scripts/benchmark_dpo_loss.py @@ -0,0 +1,179 @@ +import os +import sys + +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.utils import infer_device + +device = infer_device() + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO + from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO + + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + beta = input.extra_benchmark_config["beta"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + + # Instantiate once and retrieve the first output only + torch_dpo_loss = TorchLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device) + liger_dpo_loss = LigerLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device) + torch_fwd = lambda x, ref_x, target: torch_dpo_loss(x, ref_x, target)[0] + liger_fwd = lambda x, ref_x, target: liger_dpo_loss(x, ref_x, target)[0] + + # Input shape: [B, T, H] + _input = torch.randn(B, T, H, device=device, dtype=dtype) + ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) + # Target shape: [B, T] + target = torch.randint(V, (B, T), dtype=torch.long, device=device) + + # Add ignore_index tokens to simulate padding + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + def fwd(): + if provider == "liger": + return liger_fwd(_input, ref_input, target) + elif provider == "huggingface": + return torch_fwd(_input, ref_input, target) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO + from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO + + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + beta = input.extra_benchmark_config["beta"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + # Instantiate once and retrieve the first output only + torch_dpo_loss = TorchLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device) + 
liger_dpo_loss = LigerLMHeadDPO(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device) + torch_fwd = lambda x, ref_x, target: torch_dpo_loss(x, ref_x, target)[0] + liger_fwd = lambda x, ref_x, target: liger_dpo_loss(x, ref_x, target)[0] + + # Input shape: [B, T, H] + _input = torch.randn(B, T, H, device=device, dtype=dtype) + ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) + # Target shape: [B, T] + target = torch.randint(V, (B, T), device=device, dtype=torch.long) + + # Add ignore_index tokens + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + def fwd(): + if provider == "liger": + return liger_fwd(_input, ref_input, target) + elif provider == "huggingface": + return torch_fwd(_input, ref_input, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[_input], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "dpo_loss", + "x_name": "B", + "x_label": "Batch Size (B)", + "x_values": [2**i for i in range(1, 6)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "T": 512, + "H": 1024, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + "bias": True, + "beta": 0.1, + "ignore_index": 42, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_dpo_loss, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + + run_benchmarks( + bench_test_fn=bench_memory_dpo_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_dyt.py b/benchmark/scripts/benchmark_dyt.py new file mode 100755 index 0000000000000000000000000000000000000000..2c5129000d93001f4c585b58e6b68d143e0685cf --- /dev/null +++ b/benchmark/scripts/benchmark_dyt.py @@ -0,0 +1,96 @@ +import os +import sys + +import torch + +from benchmark_model_configs import compute_hidden_size_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import parse_benchmark_script_args +from utils import run_benchmarks +from utils import run_memory_benchmark +from utils import run_speed_benchmark + +from liger_kernel.utils import infer_device + +device = infer_device() + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +def _setup_dyt(input: SingleBenchmarkRunInput): + """Create input tensor and DyT layer from benchmark config.""" + from test.transformers.test_dyt import LigerDyT + from test.transformers.test_dyt import TorchDyT + + cfg = input.extra_benchmark_config + hidden_size = input.x + x = torch.randn(cfg["BT"], hidden_size, 
device=device, dtype=cfg["dtype"], requires_grad=True) + if input.kernel_provider == "liger": + layer = LigerDyT(hidden_size=hidden_size, beta=cfg["beta"]).to(device) + elif input.kernel_provider == "torch": + layer = TorchDyT(hidden_size=hidden_size, beta=cfg["beta"]).to(device) + elif input.kernel_provider == "torch_compile": + layer = torch.compile(TorchDyT(hidden_size=hidden_size, beta=cfg["beta"]).to(device)) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for DyT") + return x, layer + + +def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _setup_dyt(input) + return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) + + +def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _setup_dyt(input) + return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) + + +BT = 4096 + +if __name__ == "__main__": + args = parse_benchmark_script_args() + model = get_benchmark_model_config(args.model) + + for beta in [False, True]: + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=model.hidden_size, + kernel_provider="torch", + extra_benchmark_config={"BT": BT, "dtype": model.dtype, "beta": beta}, + ) + x, layer = _setup_dyt(probe_input) + return layer(x) + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + sweep_config = compute_hidden_size_sweep_config(model, peak_bytes, bt=BT) + x_values = [1024 * i for i in range(1, 17) if 1024 * i <= sweep_config.max_hidden_size] or [model.hidden_size] + + common_configs = { + "kernel_name": f"dyt_beta={beta}", + "x_name": "hidden_size", + "x_label": "hidden_size", + "x_values": x_values, + "kernel_providers": ["liger", "torch", "torch_compile"], + "extra_benchmark_configs": [{"BT": sweep_config.bt, "dtype": model.dtype, "beta": beta}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_dyt, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_dyt, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_embedding.py b/benchmark/scripts/benchmark_embedding.py new file mode 100755 index 0000000000000000000000000000000000000000..2bd0c60be9735017eb1ab219eddb8b49773d360d --- /dev/null +++ b/benchmark/scripts/benchmark_embedding.py @@ -0,0 +1,134 @@ +import torch +import triton + +from torch.nn import Embedding +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.experimental.embedding import LigerEmbedding +from liger_kernel.utils import infer_device + +device = infer_device() + +# NOTE: For torch compile, we will just use default inductor settings. No further customization +# is needed. 
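+# The x axis sweeps V, the number of embeddings (i.e. the vocab size), from 2**10 to +# 2**17; the embedding dimension D stays fixed per config (768 for the BERT-style +# config, 4096 for the Llama-style one).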
+ + +def bench_speed_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + V = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + + B = input.extra_benchmark_config["B"] + T = input.extra_benchmark_config["T"] + D = input.extra_benchmark_config["D"] + dtype = input.extra_benchmark_config["dtype"] + + torch_emb = Embedding(V, D).to(device).to(dtype) + liger_emb = LigerEmbedding(V, D).to(device).to(dtype) + torch_compile_emb = torch.compile(torch_emb) + + input_ids = torch.randint(0, V, (B, T), device=device) + + def fwd(): + if provider == "liger": + return liger_emb(input_ids) + elif provider == "torch_compile": + return torch_compile_emb(input_ids) + else: + return torch_emb(input_ids) + + def full(): + output = fwd() + output.backward(torch.randn_like(output)) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) + elif mode == "backward": + output = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: output.backward(torch.randn_like(output), retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[input_ids], + rep=100, + ) + elif mode == "full": + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + V = input.x + provider = input.kernel_provider + + B = input.extra_benchmark_config["B"] + T = input.extra_benchmark_config["T"] + D = input.extra_benchmark_config["D"] + dtype = input.extra_benchmark_config["dtype"] + + torch_emb = Embedding(V, D).to(device).to(dtype) + liger_emb = LigerEmbedding(V, D).to(device).to(dtype) + torch_compile_emb = torch.compile(torch_emb) + + input_ids = torch.randint(0, V, (B, T), device=device) + + def fwd(): + if provider == "liger": + return liger_emb(input_ids) + elif provider == "torch_compile": + return torch_compile_emb(input_ids) + else: + return torch_emb(input_ids) + + def full(): + output = fwd() + output.backward(torch.randn_like(output)) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "embedding", + "x_name": "V", + "x_label": "vocab size", + "x_values": [2**i for i in range(10, 18)], + "kernel_providers": ["liger", "huggingface", "torch_compile"], + "extra_benchmark_configs": [ + # BERT + {"B": 32, "T": 512, "D": 768, "dtype": torch.float32}, + # Llama + {"B": 8, "T": 2048, "D": 4096, "dtype": torch.float32}, + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_embedding, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_embedding, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_fused_add_rms_norm.py b/benchmark/scripts/benchmark_fused_add_rms_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..935871e9065a0a65c555ba7097a9be30866565f3 --- /dev/null +++ b/benchmark/scripts/benchmark_fused_add_rms_norm.py @@ -0,0 +1,201 @@
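+# Fused add + RMS norm folds the residual addition into the normalization kernel: +# h = x + r; y = weight * h / sqrt(mean(h**2) + eps); both y and the updated residual +# h are returned so the residual stream is refreshed in the same pass (see +# NaiveAddRMSNorm below for the reference math). +import torch +import torch.nn as nn +import triton + +from utils import QUANTILES +from utils import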
SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm +from liger_kernel.transformers.rms_norm import LigerRMSNorm +from liger_kernel.utils import infer_device + +device = infer_device() + + +class NaiveAddRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Naive implementation of the add residual rms norm. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, residual): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + residual = residual.to(torch.float32) + hidden_states = hidden_states + residual + residual = hidden_states + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype), residual.to(input_dtype) + + +class AddLigerRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + AddLigerRMSNorm is equivalent to NaiveAddRMSNorm class above, but uses the LigerRMSNorm kernel. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.rms_norm = LigerRMSNorm(hidden_size, eps, in_place=False) + + def forward(self, hidden_states, residual): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + residual = residual.to(torch.float32) + hidden_states = hidden_states + residual + residual = hidden_states + hidden_states = self.rms_norm(hidden_states) + return self.weight * hidden_states.to(input_dtype), residual.to(input_dtype) + + +def bench_speed_fused_residual_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + N = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + + extra_benchmark_config = input.extra_benchmark_config + M = extra_benchmark_config["M"] + eps = extra_benchmark_config["eps"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (M, N) + + # Fused Add RMS Norm + fused_add_rms_norm = LigerFusedAddRMSNorm(hidden_size=N, eps=eps).to(device) + # Naive implementation + naive_rms_norm = NaiveAddRMSNorm(hidden_size=N, eps=eps).to(device) + # LigerRMSNorm without fused residual addition + liger_rms_norm = AddLigerRMSNorm(hidden_size=N, eps=eps).to(device) + + x = torch.randn(x_shape, dtype=dtype, device=device) + r = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + ds = torch.randn_like(r) + x.requires_grad_(True) + r.requires_grad_(True) + # utility functions + + def y_fwd(): + if provider == "liger_fused_add_rms_norm": + return fused_add_rms_norm(x, r) + + if provider == "huggingface": + return naive_rms_norm(x, r) + + if provider == "liger_rms_norm": + return liger_rms_norm(x, r) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + y_fwd, + grad_to_none=[x, r], + rep=500, + quantiles=QUANTILES, + ) + elif mode == "backward": + y, s = y_fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: torch.autograd.backward((y, s), (dy, ds), retain_graph=True), + grad_to_none=[x, r], + rep=500, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y, s = y_fwd() + torch.autograd.backward((y, s), (dy, ds)) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + 
grad_to_none=[x, r], + rep=500, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_fused_residual_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + N = input.x + provider = input.kernel_provider + + extra_benchmark_config = input.extra_benchmark_config + M = extra_benchmark_config["M"] + eps = extra_benchmark_config["eps"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (M, N) + + fused_add_rms_norm = LigerFusedAddRMSNorm(hidden_size=N, eps=eps).to(device) + naive_rms_norm = NaiveAddRMSNorm(hidden_size=N, eps=eps).to(device) + liger_rms_norm = AddLigerRMSNorm(hidden_size=N, eps=eps).to(device) + + x = torch.randn(x_shape, dtype=dtype, device=device) + r = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + ds = torch.randn_like(r) + x.requires_grad_(True) + r.requires_grad_(True) + + # utility functions + def y_fwd(): + if provider == "liger_fused_add_rms_norm": + return fused_add_rms_norm(x, r) + if provider == "huggingface": + return naive_rms_norm(x, r) + if provider == "liger_rms_norm": + return liger_rms_norm(x, r) + + def full(): + y, s = y_fwd() + torch.autograd.backward((y, s), (dy, ds)) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "fused_add_rms_norm", + "x_name": "H", + "x_label": "hidden size", + "x_values": [2**i for i in range(10, 16)], + "kernel_providers": ["liger_fused_add_rms_norm", "huggingface", "liger_rms_norm"], + "extra_benchmark_configs": [{"M": 2048, "dtype": torch.float32, "eps": 1e-6}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_residual_rms_norm, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_residual_rms_norm, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_fused_linear_cross_entropy.py b/benchmark/scripts/benchmark_fused_linear_cross_entropy.py new file mode 100755 index 0000000000000000000000000000000000000000..4d36a66a6394ec7f9104d0300dd599acf332c29b --- /dev/null +++ b/benchmark/scripts/benchmark_fused_linear_cross_entropy.py @@ -0,0 +1,184 @@ +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss +from liger_kernel.utils import infer_device + +device = infer_device() + + +class TorchLMHeadCE(torch.nn.Module): + """Ground truth implementation of the linear fused with torch based cross entropy loss. 
+ + :param H: hidden size + :param V: vocab size + :param ignore_index: index to ignore + :param reduction: reduction method + """ + + def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) + self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction="mean") + + def forward(self, x, y): + logits = self.lin(x) + return self.ce_loss(logits, y) + + +class LigerLMHeadCE(torch.nn.Module): + def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100, accum_dtype=None): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype) + self.ce_loss = LigerFusedLinearCrossEntropyLoss( + ignore_index=ignore_index, reduction="mean", accum_dtype=accum_dtype + ) + + def forward(self, x, y): + return self.ce_loss(self.lin.weight, x, y) + + +############################################################################# +# Test the memory consumption of the linear fused cross entropy loss +############################################################################# + + +def bench_memory_fused_linear_cross_entropy( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + + lm_head_ce = None + if provider == "liger": + lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device) + elif provider == "liger-fp32-accum": + lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device) + else: + lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device) + + _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device) + target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1) + + def fwd(): + return lm_head_ce(_input, target) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +# ############################################################################# +# # Test the speed of the fused linear cross entropy loss +# ############################################################################# + + +def bench_speed_fused_linear_cross_entropy( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + lm_head_ce = None + if provider == "liger": + lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device) + elif provider == "liger-fp32-accum": + lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device) + else: + lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device) + + _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device) + target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1) + + def fwd(): + return lm_head_ce(_input, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "no-grad-forward": + with torch.no_grad(): + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) 
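The reason a fused linear + cross-entropy loss is benchmarked for memory at all is that the unfused path must materialize the full `BT x V` logits matrix before the loss. Back-of-the-envelope arithmetic for the configuration used below (H=4096, V=128256, bf16):

```python
# Rough logits-memory arithmetic for the config benchmarked below. Illustrative only.
BT = 2**15          # batch * sequence length, the largest x_value below
V = 128256          # vocab size
bytes_per_el = 2    # bfloat16

logits_gib = BT * V * bytes_per_el / 1024**3
print(f"logits alone: {logits_gib:.1f} GiB")   # ~7.8 GiB, before the matching gradient buffer
# The fused/chunked loss processes slices of BT at a time, so peak memory scales
# with the chunk size rather than with the full BT x V matrix.
```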
+ elif mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[_input], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "fused_linear_cross_entropy", + "x_name": "BT", + "x_label": "B x T", + "x_values": [2**i for i in range(12, 16)], + "kernel_providers": ["liger", "liger-fp32-accum", "huggingface"], + "extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_cross_entropy, + kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_cross_entropy, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_fused_linear_jsd.py b/benchmark/scripts/benchmark_fused_linear_jsd.py new file mode 100755 index 0000000000000000000000000000000000000000..ac62863b216a94f4fd9f9970145e394bbe20ebd8 --- /dev/null +++ b/benchmark/scripts/benchmark_fused_linear_jsd.py @@ -0,0 +1,260 @@ +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD +from liger_kernel.utils import infer_device + +device = infer_device() + + +class TorchJSD(torch.nn.Module): + def __init__( + self, + beta: float = 0.5, + ignore_index: int = -100, + dtype: torch.dtype = torch.float, + ): + super(TorchJSD, self).__init__() + self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True) + self.beta = beta + self.ignore_index = ignore_index + self.dtype = dtype + + def forward( + self, + log_q: torch.Tensor, # input + log_p: torch.Tensor, # target + label=None, + ): + log_p, log_q = log_p.to(torch.float), log_q.to(torch.float) + log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1)) + m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta) + loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (1 - self.beta) * self.kl( + torch.log(m), log_q + ).sum(dim=-1) + + if label is not None: + loss = torch.where(label != self.ignore_index, loss, 0.0) + n_non_ignore = (label != self.ignore_index).sum().item() + if n_non_ignore == 0: + loss = 0.0 + else: + loss = (loss / n_non_ignore).sum() + else: + loss = (loss / log_q.shape[0]).sum() + return loss.to(self.dtype) + + +class TorchLMHeadJSD(torch.nn.Module): + """Ground truth implementation of the linear fused with torch based jsd loss. 
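`TorchJSD` above implements the generalized Jensen-Shannon divergence: with mixture `M = beta * P + (1 - beta) * Q`, the loss is `beta * KL(P || M) + (1 - beta) * KL(Q || M)`. A small numeric sketch of that identity written directly against `torch.nn.functional`:

```python
import torch
import torch.nn.functional as F

beta = 0.5
p = torch.softmax(torch.randn(4, 10), dim=-1)   # target distribution
q = torch.softmax(torch.randn(4, 10), dim=-1)   # student distribution
m = beta * p + (1 - beta) * q                   # mixture

# F.kl_div(input=log m, target=p) gives pointwise p * (log p - log m)
kl_pm = F.kl_div(m.log(), p, reduction="none").sum(-1)
kl_qm = F.kl_div(m.log(), q, reduction="none").sum(-1)
jsd = (beta * kl_pm + (1 - beta) * kl_qm).mean()
print(jsd)  # bounded by log(2) when beta = 0.5
```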
+ + :param H: hidden size + :param V: vocab size + :param temperature: softmax temperature + :param beta: jsd beta + """ + + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + ): + super().__init__() + self.student_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device) + self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device) + self.jsd = TorchJSD(beta=beta, ignore_index=ignore_index, dtype=dtype) + self.temperature = temperature + + def forward(self, student_input, teacher_input, label=None): + student_logits = self.student_lin(student_input) + teacher_logits = self.teacher_lin(teacher_input) + student_prob = torch.log_softmax(student_logits / self.temperature, dim=-1) + teacher_prob = torch.log_softmax(teacher_logits / self.temperature, dim=-1) + + return self.jsd(student_prob, teacher_prob, label) + + +class LigerLMHeadJSD(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + ): + super().__init__() + self.student_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device) + self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device) + self.fused_jsd = LigerFusedLinearJSD(jsd_beta=beta, ignore_index=ignore_index, temperature=temperature) + + def forward(self, student_input, teacher_input, label=None): + return self.fused_jsd( + student_input, + self.student_lin.weight, + teacher_input, + self.teacher_lin.weight, + label, + ) + + +############################################################################# +# Test the memory consumption of the fused linear JSD +############################################################################# + + +def bench_memory_fused_linear_jsd( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + + torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + + # init the linear in all FusedLinearJSDs with the same weights + torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) + torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) + + student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device) + teacher_input = torch.rand(BT, H, dtype=dtype, device=device) + + def fwd(): + if provider == "liger": + return liger_lm_head_jsd(student_input, teacher_input) + elif provider == "torch": + return torch_lm_head_jsd(student_input, teacher_input) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +# ############################################################################# +# # Test the speed of the fused linear JSD +# ############################################################################# + + +def 
bench_speed_fused_linear_jsd( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + mode = input.kernel_operation_mode + + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + + torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + + # init the linear in all FusedLinearJSDs with the same weights + torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) + torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) + + student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device) + teacher_input = torch.rand(BT, H, dtype=dtype, device=device) + + def fwd(): + if provider == "liger": + return liger_lm_head_jsd(student_input, teacher_input) + elif provider == "torch": + return torch_lm_head_jsd(student_input, teacher_input) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[ + student_input, + torch_lm_head_jsd.student_lin.weight, + torch_lm_head_jsd.teacher_lin.weight, + ], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "fused_linear_jsd", + "x_name": "BT", + "x_label": "B x T", + "x_values": [2**i for i in range(10, 14)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_jsd, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_jsd, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_fused_neighborhood_attention.py b/benchmark/scripts/benchmark_fused_neighborhood_attention.py new file mode 100755 index 0000000000000000000000000000000000000000..515d65cad090d098db93624d99e206a73d602330 --- /dev/null +++ b/benchmark/scripts/benchmark_fused_neighborhood_attention.py @@ -0,0 +1,367 @@ +import math + +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.fused_neighborhood_attention import LigerFusedNeighborhoodAttention +from liger_kernel.utils import infer_device + +device = infer_device() + + +class TorchNeighborhoodAttention(torch.nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + kernel_size: int = 7, + dilation: int = 1, + bias: bool = True, + 
dropout: float = 0.0, + scale: float = None, + ): + super().__init__() + + if hidden_size % num_heads != 0: + raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})") + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.kernel_size = kernel_size + self.dilation = dilation + self.scale = scale if scale is not None else 1.0 / math.sqrt(self.head_dim) + + self.q_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias) + self.k_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias) + self.v_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias) + self.out_proj = torch.nn.Linear(hidden_size, hidden_size, bias=bias) + + if dropout > 0.0: + self.dropout = torch.nn.Dropout(dropout) + else: + self.dropout = None + + def _create_neighborhood_mask(self, seq_len: int, device: torch.device) -> torch.Tensor: + mask = torch.zeros(seq_len, seq_len, device=device, dtype=torch.bool) + half_kernel = self.kernel_size // 2 + + for i in range(seq_len): + start = max(0, i - half_kernel * self.dilation) + end = min(seq_len, i + half_kernel * self.dilation + 1) + + for j in range(start, end): + if self.dilation == 1 or (j - i) % self.dilation == 0: + mask[i, j] = True + + return mask + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, hidden_size = hidden_states.shape + + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale + + mask = self._create_neighborhood_mask(seq_len, hidden_states.device) + scores = scores.masked_fill(~mask, float("-inf")) + + attn_weights = torch.softmax(scores, dim=-1) + + if self.dropout is not None: + attn_weights = self.dropout(attn_weights) + + attn_output = torch.matmul(attn_weights, value) + + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size) + + output = self.out_proj(attn_output) + + return output + + +def bench_speed_fused_neighborhood_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + seq_len = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + + extra_benchmark_config = input.extra_benchmark_config + batch_size = extra_benchmark_config["batch_size"] + hidden_size = extra_benchmark_config["hidden_size"] + num_heads = extra_benchmark_config["num_heads"] + kernel_size = extra_benchmark_config["kernel_size"] + dilation = extra_benchmark_config["dilation"] + bias = extra_benchmark_config["bias"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (batch_size, seq_len, hidden_size) + + liger_attn = ( + LigerFusedNeighborhoodAttention( + hidden_size=hidden_size, + num_heads=num_heads, + kernel_size=kernel_size, + dilation=dilation, + bias=bias, + dropout=0.0, + ) + .to(device) + .to(dtype) + ) + + torch_attn = ( + TorchNeighborhoodAttention( + hidden_size=hidden_size, + num_heads=num_heads, + kernel_size=kernel_size, + dilation=dilation, + bias=bias, + dropout=0.0, + ) + .to(device) + .to(dtype) + ) + + with torch.no_grad(): + torch_attn.q_proj.weight.copy_(liger_attn.q_proj.weight) + torch_attn.k_proj.weight.copy_(liger_attn.k_proj.weight) + 
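`_create_neighborhood_mask` above restricts each query position to a window of `kernel_size` keys (optionally dilated), which is what makes neighborhood attention cheaper than full attention. A small sketch that prints the mask so the window structure is visible:

```python
import torch

def neighborhood_mask(seq_len, kernel_size, dilation=1):
    """Boolean (seq_len, seq_len) mask: True where query i may attend key j."""
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    half = kernel_size // 2
    for i in range(seq_len):
        for j in range(max(0, i - half * dilation), min(seq_len, i + half * dilation + 1)):
            if dilation == 1 or (j - i) % dilation == 0:
                mask[i, j] = True
    return mask

print(neighborhood_mask(8, kernel_size=3).int())
# Each row has at most kernel_size ones centered on the diagonal:
# O(T * k) attended pairs instead of O(T^2) for full attention.
```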
torch_attn.v_proj.weight.copy_(liger_attn.v_proj.weight) + torch_attn.out_proj.weight.copy_(liger_attn.out_proj.weight) + + if bias: + torch_attn.q_proj.bias.copy_(liger_attn.q_proj.bias) + torch_attn.k_proj.bias.copy_(liger_attn.k_proj.bias) + torch_attn.v_proj.bias.copy_(liger_attn.v_proj.bias) + torch_attn.out_proj.bias.copy_(liger_attn.out_proj.bias) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + def fwd(): + if provider == "liger": + return liger_attn(x) + elif provider == "torch": + return torch_attn(x) + + print(f"Starting Warmup for input size: {x_shape}") + _ = fwd() + if mode in ("backward", "full"): + y = _ + y.backward(dy, retain_graph=True) + print("Done Warmup") + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, grad_to_none=[x], rep=100, quantiles=QUANTILES) + elif mode == "backward": + y = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(dy, retain_graph=True), + grad_to_none=[x], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward(dy, retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_fused_neighborhood_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + seq_len = input.x + provider = input.kernel_provider + + extra_benchmark_config = input.extra_benchmark_config + batch_size = extra_benchmark_config["batch_size"] + hidden_size = extra_benchmark_config["hidden_size"] + num_heads = extra_benchmark_config["num_heads"] + kernel_size = extra_benchmark_config["kernel_size"] + dilation = extra_benchmark_config["dilation"] + bias = extra_benchmark_config["bias"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (batch_size, seq_len, hidden_size) + + liger_attn = ( + LigerFusedNeighborhoodAttention( + hidden_size=hidden_size, + num_heads=num_heads, + kernel_size=kernel_size, + dilation=dilation, + bias=bias, + dropout=0.0, + ) + .to(device) + .to(dtype) + ) + + torch_attn = ( + TorchNeighborhoodAttention( + hidden_size=hidden_size, + num_heads=num_heads, + kernel_size=kernel_size, + dilation=dilation, + bias=bias, + dropout=0.0, + ) + .to(device) + .to(dtype) + ) + + with torch.no_grad(): + torch_attn.q_proj.weight.copy_(liger_attn.q_proj.weight) + torch_attn.k_proj.weight.copy_(liger_attn.k_proj.weight) + torch_attn.v_proj.weight.copy_(liger_attn.v_proj.weight) + torch_attn.out_proj.weight.copy_(liger_attn.out_proj.weight) + + if bias: + torch_attn.q_proj.bias.copy_(liger_attn.q_proj.bias) + torch_attn.k_proj.bias.copy_(liger_attn.k_proj.bias) + torch_attn.v_proj.bias.copy_(liger_attn.v_proj.bias) + torch_attn.out_proj.bias.copy_(liger_attn.out_proj.bias) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + def fwd(): + if provider == "liger": + return liger_attn(x) + elif provider == "torch": + return torch_attn(x) + + def full(): + y = fwd() + y.backward(dy, retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "fused_neighborhood_attention", + "x_name": "seq_len", + "x_label": "sequence length", + "x_values": [2**i for i in 
range(6, 13)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "batch_size": 2, + "hidden_size": 512, + "num_heads": 8, + "kernel_size": 7, + "dilation": 1, + "bias": True, + "dtype": torch.float32, + }, + { + "batch_size": 4, + "hidden_size": 768, + "num_heads": 12, + "kernel_size": 7, + "dilation": 1, + "bias": True, + "dtype": torch.float32, + }, + { + "batch_size": 2, + "hidden_size": 1024, + "num_heads": 16, + "kernel_size": 9, + "dilation": 1, + "bias": True, + "dtype": torch.float32, + }, + { + "batch_size": 2, + "hidden_size": 512, + "num_heads": 8, + "kernel_size": 7, + "dilation": 2, + "bias": True, + "dtype": torch.float32, + }, + { + "batch_size": 2, + "hidden_size": 512, + "num_heads": 8, + "kernel_size": 7, + "dilation": 1, + "bias": True, + "dtype": torch.bfloat16, + }, + { + "batch_size": 4, + "hidden_size": 768, + "num_heads": 12, + "kernel_size": 7, + "dilation": 1, + "bias": True, + "dtype": torch.bfloat16, + }, + { + "batch_size": 2, + "hidden_size": 1024, + "num_heads": 16, + "kernel_size": 9, + "dilation": 1, + "bias": True, + "dtype": torch.bfloat16, + }, + { + "batch_size": 2, + "hidden_size": 512, + "num_heads": 8, + "kernel_size": 7, + "dilation": 2, + "bias": True, + "dtype": torch.bfloat16, + }, + ], + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_neighborhood_attention, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + + run_benchmarks( + bench_test_fn=bench_memory_fused_neighborhood_attention, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_geglu.py b/benchmark/scripts/benchmark_geglu.py new file mode 100755 index 0000000000000000000000000000000000000000..d59564bafa15b22cddb2d5cf7b6a64f01d5fa989 --- /dev/null +++ b/benchmark/scripts/benchmark_geglu.py @@ -0,0 +1,115 @@ +import math + +import torch + +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaMLP +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import parse_benchmark_script_args +from utils import run_benchmarks +from utils import run_memory_benchmark +from utils import run_speed_benchmark + +from liger_kernel.transformers.geglu import LigerGEGLUMLP +from liger_kernel.utils import infer_device + +device = infer_device() + + +def _setup_geglu(input: SingleBenchmarkRunInput): + """Create input tensor and GEGLU layer from benchmark config.""" + cfg = input.extra_benchmark_config + llama_config = LlamaConfig( + hidden_size=cfg["hidden_size"], + intermediate_size=cfg["intermediate_size"], + hidden_act=cfg["hidden_act"], + ) + x = torch.randn( + cfg["bsz"], + input.x, + cfg["hidden_size"], + device=device, + dtype=cfg["dtype"], + requires_grad=True, + ) + if input.kernel_provider == "liger": + layer = LigerGEGLUMLP(config=llama_config).to(device).to(cfg["dtype"]) + elif input.kernel_provider == "huggingface": + layer = LlamaMLP(config=llama_config).to(device).to(cfg["dtype"]) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for GEGLU") + return x, layer + + +def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = 
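Both providers in this GEGLU script compute the same feed-forward, `down_proj(GELU(gate_proj(x)) * up_proj(x))`, with the tanh-approximated GELU selected by `hidden_act="gelu_pytorch_tanh"`. A plain-PyTorch sketch of that computation with toy, hypothetical sizes:

```python
import torch
import torch.nn.functional as F

def geglu_mlp(x, w_gate, w_up, w_down):
    """GEGLU feed-forward: down(gelu_tanh(x @ Wg^T) * (x @ Wu^T))."""
    gate = F.gelu(x @ w_gate.T, approximate="tanh")
    return (gate * (x @ w_up.T)) @ w_down.T

H, I = 64, 256                      # toy hidden / intermediate sizes (hypothetical)
x = torch.randn(2, 16, H)
w_gate, w_up = torch.randn(I, H), torch.randn(I, H)
w_down = torch.randn(H, I)
print(geglu_mlp(x, w_gate, w_up, w_down).shape)  # torch.Size([2, 16, 64])
```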
_setup_geglu(input) + return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) + + +def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _setup_geglu(input) + return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + model = get_benchmark_model_config(args.model) + probe_seq_len = 1024 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_seq_len, + kernel_provider="huggingface", + extra_benchmark_config={ + "bsz": 1, + "hidden_size": model.hidden_size, + "intermediate_size": model.intermediate_size, + "hidden_act": "gelu_pytorch_tanh", + "dtype": model.dtype, + }, + ) + x, layer = _setup_geglu(probe_input) + return layer(x) + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_seq_len + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "geglu", + "x_name": "T", + "x_label": "sequence length", + "x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "bsz": config.batch_size, + "hidden_size": model.hidden_size, + "intermediate_size": model.intermediate_size, + "hidden_act": "gelu_pytorch_tanh", + "dtype": model.dtype, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_geglu, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_geglu, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_group_norm.py b/benchmark/scripts/benchmark_group_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..5a8bf37f4af06260409085501c6a0329ae5bfdc4 --- /dev/null +++ b/benchmark/scripts/benchmark_group_norm.py @@ -0,0 +1,137 @@ +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.group_norm import LigerGroupNorm +from liger_kernel.utils import infer_device + +device = infer_device() + + +def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + C = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + extra_benchmark_config = input.extra_benchmark_config + M = extra_benchmark_config["M"] + H = extra_benchmark_config["H"] + channels_per_group = extra_benchmark_config["channels_per_group"] + eps = extra_benchmark_config["eps"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (M, C, H) + triton_ln = LigerGroupNorm(num_channels=C, num_groups=C // channels_per_group, eps=eps).to(device) + torch_ln = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=eps).to(device) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + def y_fwd(): + if provider == "liger": + return triton_ln(x) + if provider == "huggingface": + return torch_ln(x) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500) 
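The group-norm sweep here fixes `channels_per_group`, so `num_groups = C // channels_per_group` and each group is normalized over `channels_per_group * H` elements per sample. A short sketch of that convention with `torch.nn.GroupNorm`:

```python
import torch

M, C, H = 128, 32, 512
channels_per_group = 4
gn = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=1e-6)

x = torch.randn(M, C, H)
y = gn(x)
# With default affine init (weight=1, bias=0), each contiguous group of 4
# channels in a sample is normalized over 4 * H elements:
group = y[0, :channels_per_group].reshape(-1)
print(group.mean().item(), group.var(unbiased=False).item())  # ~0.0, ~1.0
```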
+ elif mode == "backward": + y = y_fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(dy, retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[x], + rep=500, + ) + elif mode == "full": + + def full(): + y = y_fwd() + y.backward(dy, retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + C = input.x + provider = input.kernel_provider + extra_benchmark_config = input.extra_benchmark_config + M = extra_benchmark_config["M"] + H = extra_benchmark_config["H"] + channels_per_group = extra_benchmark_config["channels_per_group"] + eps = extra_benchmark_config["eps"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (M, C, H) + triton_ln = LigerGroupNorm(num_channels=C, num_groups=C // channels_per_group, eps=eps).to(device) + torch_ln = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=eps).to(device) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + def y_fwd(): + if provider == "liger": + return triton_ln(x) + if provider == "huggingface": + return torch_ln(x) + + def full(): + y = y_fwd() + y.backward(dy, retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "group_norm", + "x_name": "C", + "x_label": "num_channels", + "x_values": [2**i for i in range(5, 12)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "M": 128, + "H": 512, + "channels_per_group": 4, + "dtype": torch.float32, + "eps": 1e-6, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_group_norm, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_group_norm, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_grpo_loss.py b/benchmark/scripts/benchmark_grpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..497d8692c7ab5688637dcfe9c8c9ec955d8cb5e7 --- /dev/null +++ b/benchmark/scripts/benchmark_grpo_loss.py @@ -0,0 +1,234 @@ +import os +import sys + +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.utils import infer_device + +device = infer_device() + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +############################################################################# +# Test the memory consumption of the linear fused GRPO loss +############################################################################# + + +def bench_memory_fused_linear_grpo_loss( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + from test.chunked_loss.test_grpo_loss import LigerLMHeadGRPO + from test.chunked_loss.test_grpo_loss import 
TorchLMHeadGRPO + + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + importance_sampling_level = input.extra_benchmark_config["importance_sampling_level"] + provider = input.kernel_provider + + # Instantiate once and retrieve the first output only + torch_lm_head_grpo = TorchLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to( + device + ) + liger_lm_head_grpo = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to( + device + ) + + # Create inputs + _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) + selected_token_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device) + attention_mask = torch.ones(B, T, device=device) + advantages = torch.randn(B, dtype=dtype, device=device) + ref_input = torch.randn(B, T, H, dtype=dtype, device=device) + + torch_fwd = lambda: torch_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[ + 0 + ] + liger_fwd = lambda: liger_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[ + 0 + ] + + def fwd(): + if provider == "liger": + return liger_fwd() + elif provider == "torch": + return torch_fwd() + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +############################################################################# +# Test the speed of the fused linear GRPO loss +############################################################################# + + +def bench_speed_fused_linear_grpo_loss( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + from test.chunked_loss.test_grpo_loss import LigerLMHeadGRPO + from test.chunked_loss.test_grpo_loss import TorchLMHeadGRPO + + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + importance_sampling_level = input.extra_benchmark_config["importance_sampling_level"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + # Instantiate once and retrieve the first output only + torch_lm_head_grpo = TorchLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to( + device + ) + liger_lm_head_grpo = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level).to( + device + ) + + # Create inputs + _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) + selected_token_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device) + attention_mask = torch.ones(B, T, device=device) + advantages = torch.randn(B, dtype=dtype, device=device) + ref_input = torch.randn(B, T, H, dtype=dtype, device=device) + + torch_fwd = lambda: torch_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[ + 0 + ] + liger_fwd = lambda: liger_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[ + 0 + ] + + def fwd(): + if provider == "liger": + return liger_fwd() + elif provider == "torch": + return torch_fwd() + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == 
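The two configs at the end of this script differ only in `importance_sampling_level`: "token" keeps one policy-to-old-policy ratio per token (original GRPO), while "sequence" averages the per-token log-ratios over the valid tokens before exponentiating (the GSPO variant). A toy sketch of the two ratio shapes, using hypothetical log-probabilities:

```python
import torch

B, T = 2, 5
logp_new = torch.randn(B, T)   # hypothetical per-token log-probs under the new policy
logp_old = torch.randn(B, T)   # ... under the old policy
mask = torch.ones(B, T)        # attention mask (1 = real token)

log_ratio = logp_new - logp_old

# "token": one importance ratio per token (GRPO)
token_ratio = log_ratio.exp()                           # shape (B, T)

# "sequence": mean log-ratio per sequence, then exponentiate (GSPO)
seq_log_ratio = (log_ratio * mask).sum(-1) / mask.sum(-1)
seq_ratio = seq_log_ratio.exp()                         # shape (B,)
print(token_ratio.shape, seq_ratio.shape)
```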
"backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[_input], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + # Benchmark token-level importance sampling (original GRPO) + token_configs = { + "kernel_name": "fused_linear_grpo_loss_token", + "x_name": "B", + "x_label": "B", + "x_values": [2**i for i in range(1, 5)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "T": 1024, + "H": 4096, + "V": 128256, + "importance_sampling_level": "token", + "dtype": torch.bfloat16, + } + ], + "overwrite": args.overwrite, + } + + # Benchmark sequence-level importance sampling (GSPO) + sequence_configs = { + "kernel_name": "fused_linear_grpo_loss_sequence", + "x_name": "B", + "x_label": "B", + "x_values": [2**i for i in range(1, 5)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "T": 1024, + "H": 4096, + "V": 128256, + "importance_sampling_level": "sequence", + "dtype": torch.bfloat16, + } + ], + "overwrite": args.overwrite, + } + + # Run benchmarks for token-level (GRPO) + print("Benchmarking GRPO (token-level importance sampling)...") + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_grpo_loss, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **token_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_grpo_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **token_configs, + ) + + # Run benchmarks for sequence-level (GSPO) + print("Benchmarking GSPO (sequence-level importance sampling)...") + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_grpo_loss, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **sequence_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_grpo_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **sequence_configs, + ) diff --git a/benchmark/scripts/benchmark_jsd.py b/benchmark/scripts/benchmark_jsd.py new file mode 100755 index 0000000000000000000000000000000000000000..16d71eac042f7736989d5e93e2279360cd0f68dd --- /dev/null +++ b/benchmark/scripts/benchmark_jsd.py @@ -0,0 +1,157 @@ +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.jsd import LigerJSD +from liger_kernel.utils import get_total_gpu_memory +from liger_kernel.utils import infer_device + +device = infer_device() + + +class TorchJSD(torch.nn.Module): + def __init__( + self, + beta: float = 0.5, + ignore_index: int = -100, + dtype: torch.dtype = torch.float, + ): + super(TorchJSD, self).__init__() + self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True) + self.beta = beta + self.ignore_index = ignore_index + self.dtype = dtype + + def forward( + self, + log_q: torch.Tensor, # input + log_p: torch.Tensor, # target + label=None, + ): + log_p, log_q = log_p.to(torch.float), log_q.to(torch.float) + 
log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1)) + m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta) + loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (1 - self.beta) * self.kl( + torch.log(m), log_q + ).sum(dim=-1) + + if label is not None: + loss = torch.where(label != self.ignore_index, loss, 0.0) + n_non_ignore = (label != self.ignore_index).sum().item() + if n_non_ignore == 0: + loss = 0.0 + else: + loss = (loss / n_non_ignore).sum() + else: + loss = (loss / log_q.shape[0]).sum() + return loss.to(self.dtype) + + +def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + V = input.x + B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] + torch_jsd = TorchJSD() + liger_jsd = LigerJSD() + + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1) + target = torch.randn(B * T, V, device=device).log_softmax(dim=-1) + + def fwd(): + if input.kernel_provider == "liger": + return liger_jsd(_input, target) + else: + return torch_jsd(_input, target) + + if input.kernel_operation_mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) + elif input.kernel_operation_mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[_input], + rep=100, + ) + elif input.kernel_operation_mode == "full": + + def full(): + y = fwd() + y.backward(retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + torch_jsd = TorchJSD() + liger_jsd = LigerJSD() + + V = input.x + B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] + + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1) + target = torch.randn(B * T, V, device=device).log_softmax(dim=-1) + + def fwd(): + if input.kernel_provider == "liger": + return liger_jsd(_input, target) + else: + return torch_jsd(_input, target) + + def full(): + y = fwd() + y.backward(retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + gpu_memory_gbs = get_total_gpu_memory() + # We know that the full test will require 54GBs for vocab size 2^17 on torch + if gpu_memory_gbs >= 54: + x_max = 17 + else: + x_max = 16 + common_args = { + "kernel_name": "jsd", + "x_name": "V", + "x_label": "vocab size", + "x_values": [2**i for i in range(12, x_max + 1)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [{"B": 4, "T": 2048}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_memory_jsd, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_args, + ) + + run_benchmarks( + bench_test_fn=bench_speed_jsd, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_args, + ) diff --git a/benchmark/scripts/benchmark_kl_div.py b/benchmark/scripts/benchmark_kl_div.py new file mode 100755 index 0000000000000000000000000000000000000000..09948c38b48e0d4b9c549ea374064af833dab0da --- /dev/null +++ 
b/benchmark/scripts/benchmark_kl_div.py @@ -0,0 +1,117 @@ +import torch +import torch.nn as nn +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.kl_div import LigerKLDIVLoss +from liger_kernel.utils import infer_device + +device = infer_device() + +S, E = 12, 18 + + +def bench_speed_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + reduction = "batchmean" + V = input.x + B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] + torch_kl_div = nn.KLDivLoss(reduction=reduction) + liger_kl_div = LigerKLDIVLoss(reduction=reduction) + + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1) + target = torch.randn(B * T, V, device=device).softmax(dim=-1) + + def fwd(): + if input.kernel_provider == "liger": + return liger_kl_div(_input, target) + else: + return torch_kl_div(_input, target) + + if input.kernel_operation_mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) + elif input.kernel_operation_mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[_input], + rep=100, + ) + elif input.kernel_operation_mode == "full": + + def full(): + y = fwd() + y.backward(retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + reduction = "batchmean" + torch_kl_div = nn.KLDivLoss(reduction=reduction) + liger_kl_div = LigerKLDIVLoss(reduction=reduction) + + V = input.x + B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] + + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1) + target = torch.randn(B * T, V, device=device).softmax(dim=-1) + + def fwd(): + if input.kernel_provider == "liger": + return liger_kl_div(_input, target) + else: + return torch_kl_div(_input, target) + + def full(): + y = fwd() + y.backward(retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + common_args = { + "kernel_name": "kl_div", + "x_name": "V", + "x_label": "vocab size", + "x_values": [2**i for i in range(12, 18)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [{"B": 8, "T": 512}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_memory_kldiv, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_args, + ) + + run_benchmarks( + bench_test_fn=bench_speed_kldiv, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_args, + ) diff --git a/benchmark/scripts/benchmark_kto_loss.py b/benchmark/scripts/benchmark_kto_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..bbde1d5c6b5749e4ce8067231ee3c7a15fb47e7e --- /dev/null +++ b/benchmark/scripts/benchmark_kto_loss.py @@ -0,0 +1,314 @@ +import os +import sys + +import torch +import triton + +from utils 
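The KL benchmark above follows the `nn.KLDivLoss` conventions: the input is log-probabilities, the target is probabilities (`log_target=False`), and "batchmean" divides the summed divergence by the batch dimension. A minimal sketch of exactly that contract:

```python
import torch
import torch.nn as nn

kl = nn.KLDivLoss(reduction="batchmean")       # expects log-probs input, probs target
x = torch.randn(8, 100).log_softmax(dim=-1)    # log q
t = torch.randn(8, 100).softmax(dim=-1)        # p

loss = kl(x, t)                                # sum(p * (log p - log q)) / batch size
manual = (t * (t.log() - x)).sum() / x.size(0)
print(torch.allclose(loss, manual))            # True
```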
import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss +from liger_kernel.utils import infer_device + +device = infer_device() +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +class TorchLMHeadKTO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + use_bias: bool = False, + use_ref_bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + from test.chunked_loss.test_kto_loss import HFKTOLoss + + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_ref_bias, dtype=dtype) + self.KTO_loss = HFKTOLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + ).get_batch_loss_metrics + + def forward(self, x, ref_x, y, preference_labels, kl=None): + return self.KTO_loss( + weight=self.lin.weight, + _input=x, + target=y, + bias=self.lin.bias, + ref_input=ref_x, + ref_weight=self.ref_lin.weight, + ref_bias=self.ref_lin.bias, + preference_labels=preference_labels, + kl=kl, + ) + + +class LigerLMHeadKTO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + use_bias: bool = False, + use_ref_bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_ref_bias, dtype=dtype) + self.KTO_loss = LigerFusedLinearKTOLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + ) + + def forward(self, x, ref_x, y, preference_labels, kl=None): + return self.KTO_loss( + _input=x, + lin_weight=self.lin.weight, + target=y, + preference_labels=preference_labels, + bias=self.lin.bias, + ref_input=ref_x, + ref_weight=self.ref_lin.weight, + ref_bias=self.ref_lin.bias, + kl=kl, + ) + + +def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + beta = input.extra_benchmark_config["beta"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + + torch_kto_loss = TorchLMHeadKTO( + H=H, + V=V, + dtype=dtype, + use_bias=bias, + use_ref_bias=bias, + ignore_index=ignore_index, + beta=beta, + ).to(device) + + liger_kto_loss = LigerLMHeadKTO( + H=H, + V=V, + dtype=dtype, + use_bias=bias, + use_ref_bias=bias, + ignore_index=ignore_index, + beta=beta, + ).to(device) + + # Input shape: [B, T, H] + _input = torch.randn(B, T, H, device=device, dtype=dtype) + + # Target shape: [B, T] + target = torch.randint(V, (B, T), dtype=torch.long, device=device) + + # Preference labels shape: [B] + # Create binary preference labels (0 or 1) for each sequence in the batch + # Used to indicate preferred sequences (1) vs non-preferred sequences (0) + preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device) + + # Precomputed KL divergence between policy and reference distributions + kl = torch.randn(1, device=device, dtype=dtype) + + # Add 
ignore_index tokens to simulate padding + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + # Add ref_x with the same shape as _input + ref_input = torch.randn(B, T, H, device=device, dtype=dtype) + + def fwd(): + if provider == "liger": + return liger_kto_loss( + x=_input, + ref_x=ref_input, + y=target, + preference_labels=preference_labels, + kl=kl, + )[0] + elif provider == "huggingface": + return torch_kto_loss( + x=_input, + ref_x=ref_input, + y=target, + preference_labels=preference_labels, + kl=kl, + )[0] + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + beta = input.extra_benchmark_config["beta"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + torch_kto_loss = TorchLMHeadKTO( + H=H, + V=V, + dtype=dtype, + beta=beta, + ignore_index=ignore_index, + use_bias=bias, + ).to(device) + liger_kto_loss = LigerLMHeadKTO( + H=H, + V=V, + dtype=dtype, + beta=beta, + ignore_index=ignore_index, + use_bias=bias, + ).to(device) + + # Input shape: [B, T, H] + _input = torch.randn(B, T, H, device=device, dtype=dtype) + + # Target shape: [B, T] + target = torch.randint(V, (B, T), device=device, dtype=torch.long) + + # Preference labels shape: [B] + # Create binary preference labels (0 or 1) for each sequence in the batch + # Used to indicate preferred sequences (1) vs non-preferred sequences (0) + preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device) + + # Precomputed KL divergence between policy and reference distributions + kl = torch.randn(1, device=device, dtype=dtype) + + # Add ignore_index tokens + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + # Add ref_x with the same shape as _input + ref_input = torch.randn(B, T, H, device=device, dtype=dtype) + + def fwd(): + if provider == "liger": + return liger_kto_loss( + x=_input, + ref_x=ref_input, + y=target, + preference_labels=preference_labels, + kl=kl, + )[0] + elif provider == "huggingface": + return torch_kto_loss( + x=_input, + ref_x=ref_input, + y=target, + preference_labels=preference_labels, + kl=kl, + )[0] + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[_input], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "kto_loss", + 
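Both KTO benchmark functions simulate padding by scattering `ignore_index` into a random subset of the flattened targets; those positions are excluded from the loss. The same trick as a standalone sketch, using the `ignore_index = 42` from the config here:

```python
import torch

B, T, V = 4, 16, 100
ignore_index = 42                                  # matches the benchmark config
target = torch.randint(V, (B, T))

n = torch.randint(1, B * T // 2, (1,)).item()      # how many positions to mask
idx = torch.randperm(B * T)[:n]
target.view(-1)[idx] = ignore_index                # in-place: the view shares storage

frac = (target == ignore_index).float().mean()
print(f"masked fraction: {frac:.2f}")
```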
"x_name": "B", + "x_label": "Batch Size (B)", + "x_values": [2**i for i in range(1, 6)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "T": 512, + "H": 1024, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + "bias": True, + "beta": 0.1, + "ignore_index": 42, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_kto_loss, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + + run_benchmarks( + bench_test_fn=bench_memory_kto_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_layer_norm.py b/benchmark/scripts/benchmark_layer_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..0addf78ed8fae60d04869e67f317dac04f3df047 --- /dev/null +++ b/benchmark/scripts/benchmark_layer_norm.py @@ -0,0 +1,125 @@ +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.layer_norm import LigerLayerNorm +from liger_kernel.utils import infer_device + +device = infer_device() + + +def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + N = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + extra_benchmark_config = input.extra_benchmark_config + M = extra_benchmark_config["M"] + eps = extra_benchmark_config["eps"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (M, N) + triton_ln = LigerLayerNorm(hidden_size=N).to(device) + torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + def y_fwd(): + if provider == "liger": + return triton_ln(x) + if provider == "huggingface": + return torch_ln(x) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500) + elif mode == "backward": + y = y_fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(dy, retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[x], + rep=500, + ) + elif mode == "full": + + def full(): + y = y_fwd() + y.backward(dy, retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + N = input.x + provider = input.kernel_provider + dtype = input.extra_benchmark_config["dtype"] + M = input.extra_benchmark_config["M"] + eps = input.extra_benchmark_config["eps"] + + x_shape = (M, N) + + triton_ln = LigerLayerNorm(hidden_size=N).to(device) + torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + def y_fwd(): + if provider == "liger": + return triton_ln(x) + if provider == "huggingface": + return torch_ln(x) + + def full(): + y = y_fwd() + y.backward(dy, retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + 
+if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "layer_norm", + "x_name": "N", + "x_label": "hidden size", + "x_values": [2**i for i in range(10, 15)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [{"M": 4096, "dtype": torch.float32, "eps": 1e-6}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_layer_norm, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_layer_norm, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_llama4_rope.py b/benchmark/scripts/benchmark_llama4_rope.py new file mode 100755 index 0000000000000000000000000000000000000000..47d06051e4034ed06d24ea8eb08ec2014135560f --- /dev/null +++ b/benchmark/scripts/benchmark_llama4_rope.py @@ -0,0 +1,245 @@ +import torch +import triton + +from transformers.models.llama4.configuration_llama4 import Llama4TextConfig +from transformers.models.llama4.modeling_llama4 import Llama4TextRotaryEmbedding +from transformers.models.llama4.modeling_llama4 import apply_rotary_emb +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb +from liger_kernel.utils import infer_device +from liger_kernel.utils import transformers_version_dispatch + +device = infer_device() + + +def bench_speed_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + provider = input.kernel_provider + mode = input.kernel_operation_mode + + extra_benchmark_config = input.extra_benchmark_config + num_q_heads = extra_benchmark_config["num_q_heads"] + num_kv_heads = extra_benchmark_config["num_kv_heads"] + dtype = extra_benchmark_config["dtype"] + + # x can be either hidden_size or seq_len + hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x + seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x + + head_dim = hidden_size // num_q_heads + + # Create Llama4TextConfig for the rotary embedding + config = Llama4TextConfig( + hidden_size=hidden_size, + num_attention_heads=num_q_heads, + num_key_value_heads=num_kv_heads, + head_dim=head_dim, + max_position_embeddings=seq_len, + ) + + rotary_emb = transformers_version_dispatch( + "4.48.0", + Llama4TextRotaryEmbedding, + Llama4TextRotaryEmbedding, + before_kwargs={"config": config, "device": device}, + after_kwargs={"config": config, "device": device}, + ) + + q = torch.randn( + (1, seq_len, num_q_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ) + k = torch.randn( + (1, seq_len, num_kv_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ) + dq, dk = ( + torch.randn_like(q, device=device, dtype=dtype), + torch.randn_like(k, device=device), + ) + pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) + freqs_cis = rotary_emb(q, pos_ids) + + def fwd(): + if provider == "liger": + return liger_llama4_text_rotary_pos_emb(q, k, freqs_cis) + elif provider == "huggingface": + return apply_rotary_emb(q, k, freqs_cis) + else: + raise ValueError(f"Invalid provider: 
{provider} for Llama4 RoPE embedding") + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + grad_to_none=[q, k], + rep=400, + quantiles=QUANTILES, + ) + elif mode == "backward": + q_out, k_out = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True), + grad_to_none=[q, k], + rep=400, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + q_out, k_out = fwd() + torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + grad_to_none=[q, k], + rep=400, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + provider = input.kernel_provider + + extra_benchmark_config = input.extra_benchmark_config + num_q_heads = extra_benchmark_config["num_q_heads"] + num_kv_heads = extra_benchmark_config["num_kv_heads"] + dtype = extra_benchmark_config["dtype"] + + # x can be either hidden_size or seq_len + hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x + seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x + + head_dim = hidden_size // num_q_heads + + # Create Llama4TextConfig for the rotary embedding + config = Llama4TextConfig( + hidden_size=hidden_size, + num_attention_heads=num_q_heads, + num_key_value_heads=num_kv_heads, + head_dim=head_dim, + max_position_embeddings=seq_len, + ) + + rotary_emb = transformers_version_dispatch( + "4.48.0", + Llama4TextRotaryEmbedding, + Llama4TextRotaryEmbedding, + before_kwargs={"config": config, "device": device}, + after_kwargs={"config": config, "device": device}, + ) + + q = torch.randn( + (1, seq_len, num_q_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ) + k = torch.randn( + (1, seq_len, num_kv_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ) + dq, dk = ( + torch.randn_like(q, device=device, dtype=dtype), + torch.randn_like(k, device=device), + ) + pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) + freqs_cis = rotary_emb(q, pos_ids) + + def full(): + if provider == "liger": + q_out, k_out = liger_llama4_text_rotary_pos_emb(q, k, freqs_cis) + else: + q_out, k_out = apply_rotary_emb(q, k, freqs_cis) + torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory( + full, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs_varying_hidden_size = { + "kernel_name": "llama4_rope", + "x_name": "H", + "x_label": "hidden size", + "x_values": [32 * (2**i) for i in range(4, 10, 2)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "dtype": torch.bfloat16, + "seq_len": 2048, + "num_q_heads": 32, + "num_kv_heads": 8, + } + ], + "overwrite": args.overwrite, + } + run_benchmarks( + bench_test_fn=bench_speed_llama4_rope, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs_varying_hidden_size, + ) + run_benchmarks( + bench_test_fn=bench_memory_llama4_rope, + kernel_operation_modes=["full"], + metric_name="memory", 
+ metric_unit="MB", + **common_configs_varying_hidden_size, + ) + + common_configs_varying_seq_len = { + "kernel_name": "llama4_rope", + "x_name": "T", + "x_label": "sequence length", + "x_values": [2**i for i in range(10, 15)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "dtype": torch.bfloat16, + "hidden_size": 8192, + "num_q_heads": 32, + "num_kv_heads": 8, + } + ], + "overwrite": args.overwrite, + } + run_benchmarks( + bench_test_fn=bench_speed_llama4_rope, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs_varying_seq_len, + ) + run_benchmarks( + bench_test_fn=bench_memory_llama4_rope, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs_varying_seq_len, + ) diff --git a/benchmark/scripts/benchmark_mhc.py b/benchmark/scripts/benchmark_mhc.py new file mode 100755 index 0000000000000000000000000000000000000000..47cdd6336879510244f2433285c9ec6fe9fcb449 --- /dev/null +++ b/benchmark/scripts/benchmark_mhc.py @@ -0,0 +1,255 @@ +import os +import sys + +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.functional import liger_mhc_coeffs +from liger_kernel.transformers.functional import liger_mhc_post_res +from liger_kernel.transformers.functional import liger_mhc_pre +from liger_kernel.utils import infer_device + +device = infer_device() + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +def bench_speed_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + from test.transformers.test_mhc import mhc_coeffs_ref + + T = input.x + B = input.extra_benchmark_config["B"] + HC = input.extra_benchmark_config["HC"] + C = input.extra_benchmark_config["C"] + sub_kernel = input.extra_benchmark_config["sub_kernel"] + tmax = input.extra_benchmark_config["tmax"] + rms_eps = input.extra_benchmark_config["rms_eps"] + pre_eps = input.extra_benchmark_config["pre_eps"] + sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"] + post_mult = input.extra_benchmark_config["post_mult"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult) + need_grad = mode in ("backward", "full") + + x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=need_grad) + K, M = HC * C, HC * HC + 2 * HC + phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(need_grad) + b_param = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=need_grad) + alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad) + alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad) + alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad) + + grad_to_none = [x, phi, b_param, alpha_pre, alpha_post, alpha_res] if need_grad else None + + if sub_kernel == "coeffs": + + def fwd(): + if provider == "liger": + return liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg) + return mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg) + + def fwd_loss(): + h_pre, 
h_post, h_res = fwd() + return h_pre.square().mean() + h_post.square().mean() + h_res.square().mean() + + elif sub_kernel == "pre": + with torch.no_grad(): + h_pre_c, _, _ = liger_mhc_coeffs( + x.detach(), + phi.detach(), + b_param.detach(), + alpha_pre.detach(), + alpha_post.detach(), + alpha_res.detach(), + **coeffs_cfg, + ) + h_pre_c.requires_grad_(need_grad) + grad_to_none = [x, h_pre_c] if need_grad else None + + def fwd(): + if provider == "liger": + return liger_mhc_pre(x, h_pre_c) + return (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2) + + def fwd_loss(): + return fwd().square().mean() + + elif sub_kernel == "post_res": + with torch.no_grad(): + _, h_post_c, h_res_c = liger_mhc_coeffs( + x.detach(), + phi.detach(), + b_param.detach(), + alpha_pre.detach(), + alpha_post.detach(), + alpha_res.detach(), + **coeffs_cfg, + ) + h_post_c.requires_grad_(need_grad) + h_res_c.requires_grad_(need_grad) + f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16, requires_grad=need_grad) + grad_to_none = [x, f_out, h_post_c, h_res_c] if need_grad else None + + def fwd(): + if provider == "liger": + return liger_mhc_post_res(x, f_out, h_post_c, h_res_c) + return torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze( + -1 + ) * f_out.float().unsqueeze(-2) + + def fwd_loss(): + return fwd().square().mean() + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) + elif mode == "backward": + y = fwd_loss() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=grad_to_none, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd_loss() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=grad_to_none, rep=100, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + from test.transformers.test_mhc import mhc_coeffs_ref + + T = input.x + B = input.extra_benchmark_config["B"] + HC = input.extra_benchmark_config["HC"] + C = input.extra_benchmark_config["C"] + sub_kernel = input.extra_benchmark_config["sub_kernel"] + tmax = input.extra_benchmark_config["tmax"] + rms_eps = input.extra_benchmark_config["rms_eps"] + pre_eps = input.extra_benchmark_config["pre_eps"] + sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"] + post_mult = input.extra_benchmark_config["post_mult"] + provider = input.kernel_provider + + coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult) + + x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=True) + K, M = HC * C, HC * HC + 2 * HC + phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(True) + b_param = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=True) + alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True) + alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True) + alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True) + + if sub_kernel == "coeffs": + + def full(): + if provider == "liger": + hp, hpo, hr = liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg) + else: + hp, hpo, hr = mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg) + (hp.square().mean() + 
hpo.square().mean() + hr.square().mean()).backward() + + elif sub_kernel == "pre": + with torch.no_grad(): + h_pre_c, _, _ = liger_mhc_coeffs( + x.detach(), + phi.detach(), + b_param.detach(), + alpha_pre.detach(), + alpha_post.detach(), + alpha_res.detach(), + **coeffs_cfg, + ) + h_pre_c.requires_grad_(True) + + def full(): + if provider == "liger": + out = liger_mhc_pre(x, h_pre_c) + else: + out = (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2) + out.square().mean().backward() + + elif sub_kernel == "post_res": + with torch.no_grad(): + _, h_post_c, h_res_c = liger_mhc_coeffs( + x.detach(), + phi.detach(), + b_param.detach(), + alpha_pre.detach(), + alpha_post.detach(), + alpha_res.detach(), + **coeffs_cfg, + ) + h_post_c.requires_grad_(True) + h_res_c.requires_grad_(True) + f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16, requires_grad=True) + + def full(): + if provider == "liger": + out = liger_mhc_post_res(x, f_out, h_post_c, h_res_c) + else: + out = torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze( + -1 + ) * f_out.float().unsqueeze(-2) + out.square().mean().backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + for sub_kernel in ["coeffs", "pre", "post_res"]: + common_configs = { + "kernel_name": f"mhc_{sub_kernel}", + "x_name": "T", + "x_label": "Sequence Length (T)", + "x_values": [2**i for i in range(7, 12)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "B": 4, + "HC": 4, + "C": 4096, + "tmax": 20, + "rms_eps": 1e-6, + "pre_eps": 0.0, + "sinkhorn_eps": 1e-6, + "post_mult": 2.0, + "sub_kernel": sub_kernel, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_mhc, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + + run_benchmarks( + bench_test_fn=bench_memory_mhc, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_mhc_lm.py b/benchmark/scripts/benchmark_mhc_lm.py new file mode 100755 index 0000000000000000000000000000000000000000..6330a0e1a51da94bc7994731bc1bac9344e600d0 --- /dev/null +++ b/benchmark/scripts/benchmark_mhc_lm.py @@ -0,0 +1,455 @@ +import os +import sys + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.mhc import LigerMHC +from liger_kernel.utils import infer_device + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + +device = infer_device() + + +class RMSNorm(nn.Module): + def __init__(self, hidden_size: int, *, eps: float, dtype: torch.dtype, device: str): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + var = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(var + self.eps) + return x * self.weight + + +def _build_rope_cache(seq_len: int, head_dim: int, *, device: torch.device, dtype: torch.dtype): + inv_freq = 1.0 / (10000 ** 
(torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim)) + positions = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.einsum("i,j->ij", positions, inv_freq) + cos = freqs.cos().to(dtype) + sin = freqs.sin().to(dtype) + return cos, sin + + +def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1) + + +class MiniLlamaAttention(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, *, dtype: torch.dtype, device: str): + super().__init__() + assert hidden_size % num_heads == 0 + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + assert self.head_dim % 2 == 0, "head_dim must be even for RoPE" + + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device) + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + bsz, seq_len, _ = x.shape + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + q = q.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + cos, sin = _build_rope_cache(seq_len, self.head_dim, device=x.device, dtype=q.dtype) + q = _apply_rope(q, cos, sin) + k = _apply_rope(k, cos, sin) + + attn = F.scaled_dot_product_attention(q, k, v, is_causal=True) + attn = attn.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size) + return self.o_proj(attn) + + +class MiniLlamaMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_mult: int, *, dtype: torch.dtype, device: str): + super().__init__() + intermediate_size = hidden_size * intermediate_mult + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class AttentionBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, *, dtype: torch.dtype, device: str): + super().__init__() + self.norm = RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.attn = MiniLlamaAttention(hidden_size, num_heads, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.attn(self.norm(x)) + + +class MLPBlock(nn.Module): + def __init__(self, hidden_size: int, intermediate_mult: int, *, dtype: torch.dtype, device: str): + super().__init__() + self.norm = RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.mlp = MiniLlamaMLP(hidden_size, intermediate_mult, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(self.norm(x)) + + +class TorchMHC(nn.Module): + def __init__( + self, + layer: nn.Module, + *, + hc: int, + c: int, + tmax: int, + rms_eps: float, + 
pre_eps: float, + sinkhorn_eps: float, + post_mult: float, + phi_dtype: torch.dtype, + ): + super().__init__() + self.layer = layer + self.hc = int(hc) + self.c = int(c) + self.tmax = int(tmax) + self.rms_eps = float(rms_eps) + self.pre_eps = float(pre_eps) + self.sinkhorn_eps = float(sinkhorn_eps) + self.post_mult = float(post_mult) + + layer_param = next(layer.parameters()) + device = layer_param.device + + m = hc * hc + 2 * hc + k = hc * c + self.phi = nn.Parameter(torch.randn(k, m, dtype=phi_dtype, device=device) * 0.02) + self.b = nn.Parameter(torch.zeros(m, dtype=torch.float32, device=device)) + self.alpha_pre = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device)) + self.alpha_post = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device)) + self.alpha_res = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device)) + + self.layer_dtype = layer_param.dtype + + def _coeffs(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from test.transformers.test_mhc import mhc_coeffs_ref + + return mhc_coeffs_ref( + x, + self.phi, + self.b, + self.alpha_pre, + self.alpha_post, + self.alpha_res, + tmax=self.tmax, + rms_eps=self.rms_eps, + pre_eps=self.pre_eps, + sinkhorn_eps=self.sinkhorn_eps, + post_mult=self.post_mult, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_pre, h_post, h_res = self._coeffs(x) + x_in = (x.float() * h_pre.unsqueeze(-1)).sum(dim=-2) + if x_in.dtype != self.layer_dtype: + x_in = x_in.to(self.layer_dtype) + f_out = self.layer(x_in) + x_out = torch.einsum("...oi,...ic->...oc", h_res, x.float()) + h_post.unsqueeze(-1) * f_out.float().unsqueeze( + -2 + ) + return x_out.to(x.dtype) + + +class MHCDecoderLayer(nn.Module): + def __init__( + self, + mhc_cls: type[nn.Module], + *, + hidden_size: int, + hc: int, + num_heads: int, + intermediate_mult: int, + tmax: int, + dtype: torch.dtype, + device: str, + ): + super().__init__() + attn = AttentionBlock(hidden_size, num_heads, dtype=dtype, device=device) + mlp = MLPBlock(hidden_size, intermediate_mult, dtype=dtype, device=device) + self.attn = mhc_cls( + attn, + hc=hc, + c=hidden_size, + tmax=tmax, + rms_eps=1e-6, + pre_eps=1e-4, + sinkhorn_eps=1e-6, + post_mult=2.0, + phi_dtype=dtype, + ) + self.mlp = mhc_cls( + mlp, + hc=hc, + c=hidden_size, + tmax=tmax, + rms_eps=1e-6, + pre_eps=1e-4, + sinkhorn_eps=1e-6, + post_mult=2.0, + phi_dtype=dtype, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.attn(x) + x = self.mlp(x) + return x + + +class BenchMiniMHCLM(nn.Module): + def __init__( + self, + mhc_cls: type[nn.Module], + *, + vocab_size: int, + hidden_size: int, + hc: int, + num_layers: int, + num_heads: int, + intermediate_mult: int, + tmax: int, + dtype: torch.dtype, + device: str, + ): + super().__init__() + self.hc = hc + self.hidden_size = hidden_size + self.embed = nn.Embedding(vocab_size, hc * hidden_size, dtype=dtype, device=device) + self.layers = nn.ModuleList( + [ + MHCDecoderLayer( + mhc_cls, + hidden_size=hidden_size, + hc=hc, + num_heads=num_heads, + intermediate_mult=intermediate_mult, + tmax=tmax, + dtype=dtype, + device=device, + ) + for _ in range(num_layers) + ] + ) + self.final_norm = RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False, dtype=dtype, device=device) + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + x = self.embed(input_ids) + bsz, seq_len, _ = x.shape + x = x.view(bsz, seq_len, self.hc, self.hidden_size) + for layer in 
self.layers: + x = layer(x) + x = x.mean(dim=-2) + x = self.final_norm(x) + return self.lm_head(x) + + +def _build_model( + provider: str, + *, + hidden_size: int, + hc: int, + num_layers: int, + num_heads: int, + intermediate_mult: int, + vocab_size: int, + tmax: int, + dtype: torch.dtype, +): + mhc_cls = LigerMHC if provider == "liger" else TorchMHC + return BenchMiniMHCLM( + mhc_cls, + vocab_size=vocab_size, + hidden_size=hidden_size, + hc=hc, + num_layers=num_layers, + num_heads=num_heads, + intermediate_mult=intermediate_mult, + tmax=tmax, + dtype=dtype, + device=device, + ) + + +def bench_speed_mhc_lm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + hidden_size = int(input.x) + provider = input.kernel_provider + mode = input.kernel_operation_mode + extra = input.extra_benchmark_config + bsz = extra["B"] + seq_len = extra["T"] + hc = extra["HC"] + num_layers = extra["layers"] + num_heads = extra["heads"] + vocab_size = extra["vocab"] + dtype = extra["dtype"] + tmax = extra["tmax"] + intermediate_mult = extra["intermediate_mult"] + + if hidden_size % num_heads != 0: + raise ValueError("hidden_size must be divisible by num_heads") + + model = _build_model( + provider, + hidden_size=hidden_size, + hc=hc, + num_layers=num_layers, + num_heads=num_heads, + intermediate_mult=intermediate_mult, + vocab_size=vocab_size, + tmax=tmax, + dtype=dtype, + ) + + input_ids = torch.randint(0, vocab_size, (bsz, seq_len), device=device) + + def fwd(): + return model(input_ids) + + def fwd_loss(): + return fwd().float().mean() + + grad_to_none = list(model.parameters()) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, grad_to_none=grad_to_none, rep=100) + elif mode == "backward": + loss = fwd_loss() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: loss.backward(retain_graph=True), + quantiles=QUANTILES, + grad_to_none=grad_to_none, + rep=100, + ) + elif mode == "full": + + def full(): + loss = fwd_loss() + loss.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=grad_to_none, rep=100) + else: + raise ValueError(f"Unknown mode: {mode}") + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_mhc_lm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + hidden_size = int(input.x) + provider = input.kernel_provider + extra = input.extra_benchmark_config + bsz = extra["B"] + seq_len = extra["T"] + hc = extra["HC"] + num_layers = extra["layers"] + num_heads = extra["heads"] + vocab_size = extra["vocab"] + dtype = extra["dtype"] + tmax = extra["tmax"] + intermediate_mult = extra["intermediate_mult"] + + if hidden_size % num_heads != 0: + raise ValueError("hidden_size must be divisible by num_heads") + + model = _build_model( + provider, + hidden_size=hidden_size, + hc=hc, + num_layers=num_layers, + num_heads=num_heads, + intermediate_mult=intermediate_mult, + vocab_size=vocab_size, + tmax=tmax, + dtype=dtype, + ) + + input_ids = torch.randint(0, vocab_size, (bsz, seq_len), device=device) + + def fwd(): + return model(input_ids) + + def full(): + loss = fwd().float().mean() + loss.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "mhc_llama_like_lm", + "x_name": "hidden_size", + "x_label": "hidden_size", + 
"x_values": [256, 512, 1024], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "B": 2, + "T": 256, + "HC": 4, + "layers": 2, + "heads": 8, + "vocab": 4096, + "dtype": torch.bfloat16, + "tmax": 8, + "intermediate_mult": 4, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_mhc_lm, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_mhc_lm, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_model_configs.py b/benchmark/scripts/benchmark_model_configs.py new file mode 100755 index 0000000000000000000000000000000000000000..630b0d555e7050fca83ae1c856fd97ac4812a8bb --- /dev/null +++ b/benchmark/scripts/benchmark_model_configs.py @@ -0,0 +1,258 @@ +""" +Standardized benchmark model configurations. + +Provides canonical model architecture profiles and device-specific benchmark +parameters. All benchmark scripts should derive their tensor shapes from these +shared configs rather than defining ad-hoc per-script constants. + +Usage:: + + from benchmark_model_configs import ( + get_benchmark_model_config, + compute_seq_len_sweep_config, + estimate_kernel_peak_memory, + ) + + args = parse_benchmark_script_args() + model = get_benchmark_model_config(args.model) + + # Measure actual memory via a small probe, then compute sweep config + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + bpt = peak_bytes // probe_num_tokens + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=bpt) +""" + +import gc +import math + +from dataclasses import dataclass +from typing import Callable +from typing import Dict +from typing import Optional + +import torch + +from liger_kernel.utils import get_total_gpu_memory +from liger_kernel.utils import infer_device + + +@dataclass(frozen=True) +class ModelConfig: + """Canonical model architecture profile. + + Each field corresponds to a standard LLM hyperparameter. Benchmark scripts + pick the fields they need (e.g. hidden_size for RMSNorm, vocab_size for + CrossEntropy) while kernel-specific overrides (e.g. hidden_act for GEGLU) + are applied locally in the benchmark script. + """ + + name: str + hidden_size: int + intermediate_size: int + vocab_size: int + num_attention_heads: int + num_key_value_heads: int + head_dim: int + hidden_act: str + max_position_embeddings: int = 8192 + rms_norm_eps: float = 1e-5 + dtype: torch.dtype = torch.bfloat16 + + +@dataclass(frozen=True) +class SeqLenSweepConfig: + """Config for benchmarks that sweep sequence length (e.g. GEGLU, SwiGLU). + + Attributes: + batch_size: Safe batch size for the sweep. + seq_len: Max sequence length (upper bound for x_values). + """ + + batch_size: int + seq_len: int + + +@dataclass(frozen=True) +class HiddenSizeSweepConfig: + """Config for benchmarks that sweep hidden_size with fixed BT (e.g. DyT). + + Attributes: + bt: Fixed batch * seq dimension. + max_hidden_size: Upper bound for hidden_size sweep. 
+ """ + + bt: int + max_hidden_size: int + + +# ── Model Profiles ────────────────────────────────────────────────────────── + +LLAMA_2_7B = ModelConfig( + name="llama_2_7b", + hidden_size=4096, + intermediate_size=11008, + vocab_size=32000, + num_attention_heads=32, + num_key_value_heads=32, + head_dim=128, + hidden_act="silu", + max_position_embeddings=4096, +) + +LLAMA_3_8B = ModelConfig( + name="llama_3_8b", + hidden_size=4096, + intermediate_size=14336, + vocab_size=128256, + num_attention_heads=32, + num_key_value_heads=8, + head_dim=128, + hidden_act="silu", + max_position_embeddings=8192, +) + +MODEL_REGISTRY: Dict[str, ModelConfig] = { + "llama_2_7b": LLAMA_2_7B, + "llama_3_8b": LLAMA_3_8B, +} + +DEFAULT_MODEL_CONFIG = LLAMA_3_8B + + +def get_benchmark_model_config(model_name: Optional[str] = None) -> ModelConfig: + """Resolve benchmark model config from name. + + Returns the canonical model architecture profile (hidden_size, vocab_size, + dtype, etc.) for benchmark runs. Use this to obtain model attributes + when building benchmark tensors and shapes. + + Args: + model_name: Registry key (e.g. ``llama_2_7b``, ``llama_3_8b``). + If None, returns ``DEFAULT_MODEL_CONFIG``. + """ + return MODEL_REGISTRY[model_name] if model_name else DEFAULT_MODEL_CONFIG + + +def estimate_kernel_peak_memory(probe_fn: Callable[[], torch.Tensor]) -> int: + """Run a forward + backward probe to measure peak memory (bytes). + + Call this with the *pure PyTorch* (e.g. huggingface) implementation -- + that typically has the highest memory footprint and therefore gives a + safe upper-bound estimate. Returns the total peak bytes; divide by + num_tokens if you need bytes-per-token for :func:`compute_seq_len_sweep_config`. + + The probe_fn performs setup and forward pass internally; cleanup is + automatic, so callers do not need to manage tensor/layer lifecycle. + + Example:: + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // num_tokens # if needed + + Args: + probe_fn: Callable that performs setup, runs a forward pass, and + returns an output tensor suitable for ``.backward()``. + """ + device_str = infer_device() + torch_device_mod = getattr(torch, device_str) + + gc.collect() + torch_device_mod.empty_cache() + torch_device_mod.memory.reset_peak_memory_stats() + + y = probe_fn() + y.backward(torch.randn_like(y)) + + peak_bytes = torch_device_mod.max_memory_allocated() + + del y + gc.collect() + torch_device_mod.empty_cache() + + return max(1, peak_bytes) + + +def compute_seq_len_sweep_config( + model_cfg: ModelConfig, + kernel_bytes_per_token: Optional[int] = None, + memory_utilization: float = 0.4, + max_seq_len: Optional[int] = None, + max_batch_size: int = 32, +) -> SeqLenSweepConfig: + """Compute safe batch_size and seq_len for sequence-length sweep (e.g. GEGLU). + + Peak memory is estimated as + ``batch_size * seq_len * kernel_bytes_per_token`` and is capped at + device memory * memory_utilization. Device memory is obtained + internally via :func:`~liger_kernel.utils.get_total_gpu_memory`. + + Prefer obtaining *kernel_bytes_per_token* via + :func:`estimate_kernel_peak_memory` (divide by num_tokens) rather + than hardcoding an analytical estimate. + + Args: + model_cfg: Model architecture config. + kernel_bytes_per_token: Peak memory **per token** (``batch * seq_len`` + axis). Best obtained from :func:`estimate_kernel_peak_memory` / num_tokens. + Falls back to a conservative heuristic + (``hidden_size * dtype_bytes * 16``) when *None*. 
+ memory_utilization: Fraction of total device memory to target (0 to 1). + Lower values are safer. Default ``0.4`` leaves headroom for + framework overhead and CUDA/NPU context. + max_seq_len: Hard upper bound for sequence length. Defaults to + ``model_cfg.max_position_embeddings`` so the sweep never exceeds + the model's native context window. + max_batch_size: Hard upper bound for batch size. + """ + total_memory_gb = get_total_gpu_memory() + dtype_bytes = 2 if model_cfg.dtype in (torch.bfloat16, torch.float16) else 4 + + if kernel_bytes_per_token is None: + kernel_bytes_per_token = model_cfg.hidden_size * dtype_bytes * 16 + + if max_seq_len is None: + max_seq_len = model_cfg.max_position_embeddings + + usable_bytes = total_memory_gb * (1024**3) * memory_utilization + max_tokens = max(1, int(usable_bytes / kernel_bytes_per_token)) + + seq_len = min(max_seq_len, max_tokens) + seq_len = 2 ** int(math.log2(seq_len)) if seq_len >= 1024 else 1024 + + batch_size = max(1, min(max_tokens // seq_len, max_batch_size)) + + return SeqLenSweepConfig(batch_size=batch_size, seq_len=seq_len) + + +def compute_hidden_size_sweep_config( + model_cfg: ModelConfig, + kernel_peak_bytes: int, + bt: int = 4096, + memory_utilization: float = 0.4, + max_hidden_size_multiplier: int = 4, +) -> HiddenSizeSweepConfig: + """Compute safe max_hidden_size for hidden_size sweep (e.g. DyT). + + For kernels with shape (BT, hidden_size) where BT is fixed and we sweep + hidden_size. Uses probe peak memory to derive max_hidden_size. + Device memory is obtained internally via :func:`~liger_kernel.utils.get_total_gpu_memory`. + + Args: + model_cfg: Model config. + kernel_peak_bytes: Peak memory from probe (BT, model.hidden_size). + bt: Fixed BT dimension; must match the probe. + memory_utilization: Fraction of device memory to use. + max_hidden_size_multiplier: Cap max_hidden_size at model.hidden_size * this. 
+ """ + total_memory_gb = get_total_gpu_memory() + usable_bytes = total_memory_gb * (1024**3) * memory_utilization + kernel_bpt = max(1, kernel_peak_bytes // bt) + max_hidden_size = min( + model_cfg.hidden_size * max_hidden_size_multiplier, + max( + model_cfg.hidden_size, + int(usable_bytes * model_cfg.hidden_size / (bt * kernel_bpt)), + ), + ) + max_hidden_size = max(1024, 2 ** int(math.log2(max_hidden_size))) + return HiddenSizeSweepConfig(bt=bt, max_hidden_size=max_hidden_size) diff --git a/benchmark/scripts/benchmark_multi_token_attention.py b/benchmark/scripts/benchmark_multi_token_attention.py new file mode 100755 index 0000000000000000000000000000000000000000..b5319af5c70baa9ec547bdac130a98ff4ed4e6e4 --- /dev/null +++ b/benchmark/scripts/benchmark_multi_token_attention.py @@ -0,0 +1,218 @@ +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention +from liger_kernel.utils import infer_device + +device = infer_device() + + +class TorchMultiTokenAttention(torch.nn.Module): + def __init__(self, C_in, C_out, K, groups, bias, dtype, device): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(C_out, C_in // groups, K, K, dtype=dtype, device=device)) + self.bias = torch.nn.Parameter(torch.empty(C_out, dtype=dtype, device=device)) if bias else None + self.K = K + self.groups = groups + + def forward(self, scores): + B, C_in, L, _ = scores.shape + mask = torch.tril(torch.ones(L, L, dtype=torch.bool, device=scores.device)).view(1, 1, L, L) + inf = torch.tensor(-1e9, device=scores.device, dtype=scores.dtype) + zero = torch.tensor(0.0, device=scores.device, dtype=scores.dtype) + s_inf = scores.masked_fill(~mask, inf) + probs = torch.nn.functional.softmax(s_inf, dim=-1) + out_c = torch.nn.functional.conv2d( + probs, self.weight, self.bias, stride=1, padding=self.K // 2, groups=self.groups + ) + return out_c.masked_fill(~mask, zero) + + +def bench_speed_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + L = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + + extra_benchmark_config = input.extra_benchmark_config + B = extra_benchmark_config["B"] + C_in = extra_benchmark_config["C_in"] + C_out = extra_benchmark_config["C_out"] + K = extra_benchmark_config["K"] + groups = extra_benchmark_config["groups"] + bias = extra_benchmark_config["bias"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (B, C_in, L, L) + + triton_attn = ( + LigerMultiTokenAttention( + in_channels=C_in, + out_channels=C_out, + kernel_size=K, + stride=1, + padding=K // 2, + dilation=1, + groups=groups, + bias=bias, + ) + .to(device) + .to(dtype) + ) + + torch_attn = TorchMultiTokenAttention( + C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device + ) + + with torch.no_grad(): + torch_attn.weight.copy_(triton_attn.weight) + if bias: + torch_attn.bias.copy_(triton_attn.bias) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + def fwd(): + if provider == "liger": + return triton_attn(x) + elif provider == "torch": + return torch_attn(x) + + print(f"Starting Warmup for input size: {x_shape}") + _ = fwd() + if mode in ("backward", "full"): + y = _ + y.backward(dy, 
retain_graph=True) + print("Done Warmup") + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, grad_to_none=[x], rep=100, quantiles=QUANTILES) + elif mode == "backward": + y = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(dy, retain_graph=True), + grad_to_none=[x], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward(dy, retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + L = input.x + provider = input.kernel_provider + + extra_benchmark_config = input.extra_benchmark_config + B = extra_benchmark_config["B"] + C_in = extra_benchmark_config["C_in"] + C_out = extra_benchmark_config["C_out"] + K = extra_benchmark_config["K"] + groups = extra_benchmark_config["groups"] + bias = extra_benchmark_config["bias"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (B, C_in, L, L) + + triton_attn = ( + LigerMultiTokenAttention( + in_channels=C_in, + out_channels=C_out, + kernel_size=K, + stride=1, + padding=K // 2, + dilation=1, + groups=groups, + bias=bias, + ) + .to(device) + .to(dtype) + ) + + torch_attn = TorchMultiTokenAttention( + C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device + ) + + with torch.no_grad(): + torch_attn.weight.copy_(triton_attn.weight) + if bias: + torch_attn.bias.copy_(triton_attn.bias) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + def fwd(): + if provider == "liger": + return triton_attn(x) + elif provider == "torch": + return torch_attn(x) + + def full(): + y = fwd() + y.backward(dy, retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "multi_token_attention", + "x_name": "L", + "x_label": "sequence length", + "x_values": [2**i for i in range(5, 10)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "B": 2, + "C_in": 4, + "C_out": 4, + "K": 3, + "groups": 1, + "bias": True, + "dtype": torch.bfloat16, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_multi_token_attention, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_multi_token_attention, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_orpo_loss.py b/benchmark/scripts/benchmark_orpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..30b308c42605e08758165f4fb5e719f098995a5b --- /dev/null +++ b/benchmark/scripts/benchmark_orpo_loss.py @@ -0,0 +1,169 @@ +import os +import sys + +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.utils import infer_device + +device = infer_device() + 
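+# The ORPO reference heads (TorchLMHeadORPO / LigerLMHeadORPO) live under the
+# repo's test/ tree, which is not an installed package, so the repo root is
+# prepended to sys.path and their imports are deferred into the benchmark
+# functions below.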
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
+
+
+#############################################################################
+# Test the memory consumption of the fused linear ORPO loss
+#############################################################################
+
+
+def bench_memory_fused_linear_orpo_loss(
+    input: SingleBenchmarkRunInput,
+) -> SingleBenchmarkRunOutput:
+    from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO
+    from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO
+
+    B = input.x
+    T = input.extra_benchmark_config["T"]
+    H = input.extra_benchmark_config["H"]
+    V = input.extra_benchmark_config["V"]
+    dtype = input.extra_benchmark_config["dtype"]
+    provider = input.kernel_provider
+
+    # Instantiate once and retrieve the first output only
+    torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
+    liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
+    torch_fwd = lambda x, target, nll_target: torch_lm_head_orpo(x, target, nll_target)[0]
+    liger_fwd = lambda x, target, nll_target: liger_lm_head_orpo(x, target, nll_target)[0]
+
+    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
+    target = torch.randint(V, (B, T), dtype=torch.long, device=device)
+    nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)
+
+    def fwd():
+        if provider == "liger":
+            return liger_fwd(_input, target, nll_target)
+        elif provider == "huggingface":
+            return torch_fwd(_input, target, nll_target)
+
+    def full():
+        y = fwd()
+        y.backward()
+
+    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(
+        y_20=mem_20,
+        y_50=mem_50,
+        y_80=mem_80,
+    )
+
+
+#############################################################################
+# Test the speed of the fused linear ORPO loss
+#############################################################################
+
+
+def bench_speed_fused_linear_orpo_loss(
+    input: SingleBenchmarkRunInput,
+) -> SingleBenchmarkRunOutput:
+    from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO
+    from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO
+
+    B = input.x
+    T = input.extra_benchmark_config["T"]
+    H = input.extra_benchmark_config["H"]
+    V = input.extra_benchmark_config["V"]
+    dtype = input.extra_benchmark_config["dtype"]
+    provider = input.kernel_provider
+    mode = input.kernel_operation_mode
+
+    # Instantiate once and retrieve the first output only
+    torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
+    liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
+    torch_fwd = lambda x, target, nll_target: torch_lm_head_orpo(x, target, nll_target)[0]
+    liger_fwd = lambda x, target, nll_target: liger_lm_head_orpo(x, target, nll_target)[0]
+
+    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
+    target = torch.randint(V, (B, T), dtype=torch.long, device=device)
+    nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)
+
+    def fwd():
+        if provider == "liger":
+            return liger_fwd(_input, target, nll_target)
+        elif provider == "huggingface":
+            return torch_fwd(_input, target, nll_target)
+
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            fwd,
+            rep=100,
+            quantiles=QUANTILES,
+        )
+    elif mode == "backward":
+        y = fwd()
+
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: y.backward(retain_graph=True),
+            grad_to_none=[_input],
+            rep=100,
quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "fused_linear_orpo_loss", + "x_name": "B", + "x_label": "B", + "x_values": [2**i for i in range(1, 5)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "T": 1024, + "H": 4096, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_orpo_loss, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_orpo_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_poly_norm.py b/benchmark/scripts/benchmark_poly_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..ddff431d74276d57376c7589435e4e6e9ea82419 --- /dev/null +++ b/benchmark/scripts/benchmark_poly_norm.py @@ -0,0 +1,197 @@ +import torch +import torch.nn as nn +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.poly_norm import LigerPolyNorm +from liger_kernel.utils import infer_device + +device = infer_device() + + +class NaivePolyNorm(nn.Module): + """ + Naive PyTorch implementation of PolyNorm. 
+ + Reference: + https://github.com/BryceZhuo/PolyCom/ + + PolyNorm formula: + y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b + where norm(u) = u / sqrt(mean(u²) + ε) + """ + + def __init__(self, eps=1e-6): + super().__init__() + # Align with PolyCom reference: (1/3, 1/3, 1/3) and bias=1.0 + self.weight = nn.Parameter(torch.full((3,), 1.0 / 3.0)) + self.bias = nn.Parameter(torch.tensor(1.0)) + self.variance_epsilon = eps + + def _norm(self, x): + """RMSNorm operation""" + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon) + + def forward(self, hidden_states): + """ + Forward pass of PolyNorm + + Args: + hidden_states: input tensor of shape (..., H) + + Returns: + output tensor of same shape as input + """ + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + # Compute powers + x_pow3 = hidden_states**3 + x_pow2 = hidden_states**2 + x_pow1 = hidden_states**1 + + # Normalize each power + norm_x3 = self._norm(x_pow3) + norm_x2 = self._norm(x_pow2) + norm_x1 = self._norm(x_pow1) + + # Weighted sum with bias + output = self.weight[0] * norm_x3 + self.weight[1] * norm_x2 + self.weight[2] * norm_x1 + self.bias + + return output.to(input_dtype) + + +def bench_speed_poly_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + N = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + + extra_benchmark_config = input.extra_benchmark_config + M = extra_benchmark_config["M"] + eps = extra_benchmark_config["eps"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (M, N) + + triton_poly = LigerPolyNorm(eps=eps).to(device) + naive_poly = NaivePolyNorm(eps=eps).to(device) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + # utility functions + + def y_fwd(): + if provider == "liger": + return triton_poly(x) + + if provider == "huggingface": + return naive_poly(x) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + y_fwd, + grad_to_none=[x], + rep=500, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = y_fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(dy, retain_graph=True), + grad_to_none=[x], + rep=500, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = y_fwd() + y.backward(dy, retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + grad_to_none=[x], + rep=500, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_poly_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + N = input.x + provider = input.kernel_provider + + extra_benchmark_config = input.extra_benchmark_config + M = extra_benchmark_config["M"] + eps = extra_benchmark_config["eps"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (M, N) + + triton_poly = LigerPolyNorm(eps=eps).to(device) + naive_poly = NaivePolyNorm(eps=eps).to(device) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + # utility functions + def y_fwd(): + if provider == "liger": + return triton_poly(x) + if provider == "huggingface": + return naive_poly(x) + + def full(): + y = y_fwd() + y.backward(dy, retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() 
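+    # Optional correctness spot-check (a sketch, not run by the benchmark;
+    # assumes LigerPolyNorm matches the PolyCom-style (1/3, 1/3, 1/3) weight
+    # and 1.0 bias initialization used by NaivePolyNorm above):
+    #
+    #   x = torch.randn(64, 1024, dtype=torch.bfloat16, device=device)
+    #   torch.testing.assert_close(
+    #       LigerPolyNorm(eps=1e-6).to(device)(x),
+    #       NaivePolyNorm(eps=1e-6).to(device)(x),
+    #       rtol=1e-2,
+    #       atol=1e-2,
+    #   )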
+ + common_configs = { + "kernel_name": "poly_norm", + "x_name": "H", + "x_label": "hidden size", + "x_values": [2**i for i in range(10, 16)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [{"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_poly_norm, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_poly_norm, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_qwen2vl_mrope.py b/benchmark/scripts/benchmark_qwen2vl_mrope.py new file mode 100755 index 0000000000000000000000000000000000000000..ec1c53b8909c11de94bcdc17b744dba310268472 --- /dev/null +++ b/benchmark/scripts/benchmark_qwen2vl_mrope.py @@ -0,0 +1,241 @@ +import torch +import triton + +from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLTextConfig +from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding +from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb +from liger_kernel.utils import infer_device + +device = infer_device() + + +def bench_speed_qwen2vl_mrope( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + provider = input.kernel_provider + mode = input.kernel_operation_mode + + extra_benchmark_config = input.extra_benchmark_config + num_q_heads = extra_benchmark_config["num_q_heads"] + num_kv_heads = extra_benchmark_config["num_kv_heads"] + dtype = extra_benchmark_config["dtype"] + + # x can be either hidden_size or seq_len + hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x + seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x + + head_dim = hidden_size // num_q_heads + mrope_section_hw = head_dim * 3 // 16 + mrope_section = [ + head_dim // 2 - 2 * mrope_section_hw, + mrope_section_hw, + mrope_section_hw, + ] + config = Qwen2VLTextConfig( + hidden_size=hidden_size, + num_attention_heads=num_q_heads, + num_key_value_heads=num_kv_heads, + rope_theta=1000000.0, + mrope_section=mrope_section, + ) + rotary_emb = Qwen2VLRotaryEmbedding(config, device=device) + q = torch.randn( + (1, seq_len, num_q_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ).transpose(1, 2) + k = torch.randn( + (1, seq_len, num_kv_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ).transpose(1, 2) + dq, dk = ( + torch.randn_like(q, device=device, dtype=dtype), + torch.randn_like(k, device=device, dtype=dtype), + ) + pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) + cos, sin = rotary_emb(k, pos_ids) + + def fwd(): + if provider == "liger": + return liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) + elif provider == "huggingface": + return apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) + else: + raise ValueError(f"Invalid provider: {provider} for M-RoPE embedding") + + if mode == "forward": + ms_50, ms_20, ms_80 = 
triton.testing.do_bench( + fwd, + grad_to_none=[q, k], + rep=400, + quantiles=QUANTILES, + ) + elif mode == "backward": + q_out, k_out = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True), + grad_to_none=[q, k], + rep=400, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + q_out, k_out = fwd() + torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + grad_to_none=[q, k], + rep=400, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_qwen2vl_mrope( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + provider = input.kernel_provider + + extra_benchmark_config = input.extra_benchmark_config + num_q_heads = extra_benchmark_config["num_q_heads"] + num_kv_heads = extra_benchmark_config["num_kv_heads"] + dtype = extra_benchmark_config["dtype"] + + # x can be either hidden_size or seq_len + hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x + seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x + + head_dim = hidden_size // num_q_heads + + mrope_section_hw = head_dim * 3 // 16 + mrope_section = [ + head_dim // 2 - 2 * mrope_section_hw, + mrope_section_hw, + mrope_section_hw, + ] + config = Qwen2VLTextConfig( + hidden_size=hidden_size, + num_attention_heads=num_q_heads, + num_key_value_heads=num_kv_heads, + rope_theta=1000000.0, + mrope_section=mrope_section, + ) + rotary_emb = Qwen2VLRotaryEmbedding(config, device=device) + q = torch.randn( + (1, seq_len, num_q_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ).transpose(1, 2) + k = torch.randn( + (1, seq_len, num_kv_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ).transpose(1, 2) + dq, dk = ( + torch.randn_like(q, device=device, dtype=dtype), + torch.randn_like(k, device=device, dtype=dtype), + ) + pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) + cos, sin = rotary_emb(k, pos_ids) + + def full(): + if provider == "liger": + q_out, k_out = liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) + else: + q_out, k_out = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section) + torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory( + full, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs_varying_hidden_size = { + "kernel_name": "qwen2vl_mrope", + "x_name": "H", + "x_label": "hidden size", + "x_values": [32 * (2**i) for i in range(4, 10, 2)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "dtype": torch.bfloat16, + "seq_len": 2048, + "num_q_heads": 32, + "num_kv_heads": 8, + } + ], + "overwrite": args.overwrite, + } + run_benchmarks( + bench_test_fn=bench_speed_qwen2vl_mrope, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs_varying_hidden_size, + ) + run_benchmarks( + bench_test_fn=bench_memory_qwen2vl_mrope, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs_varying_hidden_size, + 
) + + common_configs_varying_seq_len = { + "kernel_name": "qwen2vl_mrope", + "x_name": "T", + "x_label": "sequence length", + "x_values": [2**i for i in range(10, 15)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "dtype": torch.bfloat16, + "hidden_size": 8192, + "num_q_heads": 32, + "num_kv_heads": 8, + } + ], + "overwrite": args.overwrite, + } + run_benchmarks( + bench_test_fn=bench_speed_qwen2vl_mrope, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs_varying_seq_len, + ) + run_benchmarks( + bench_test_fn=bench_memory_qwen2vl_mrope, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs_varying_seq_len, + ) diff --git a/benchmark/scripts/benchmark_rms_norm.py b/benchmark/scripts/benchmark_rms_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..6bcd56a8378a2f6d21dfecf5157569f3d59af9b8 --- /dev/null +++ b/benchmark/scripts/benchmark_rms_norm.py @@ -0,0 +1,162 @@ +import torch +import torch.nn as nn +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.rms_norm import LigerRMSNorm +from liger_kernel.utils import infer_device + +device = infer_device() + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +def bench_speed_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + N = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + + extra_benchmark_config = input.extra_benchmark_config + M = extra_benchmark_config["M"] + eps = extra_benchmark_config["eps"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (M, N) + + triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device) + llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to(device) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + # utility functions + + def y_fwd(): + if provider == "liger": + return triton_rms(x) + + if provider == "huggingface": + return llama_rms(x) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + y_fwd, + grad_to_none=[x], + rep=500, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = y_fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(dy, retain_graph=True), + grad_to_none=[x], + rep=500, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = y_fwd() + y.backward(dy, retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + grad_to_none=[x], + rep=500, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + N = input.x + provider = input.kernel_provider + + 
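# Rough memory model for this benchmark (a sketch, not a guarantee): one + # fwd+bwd pass keeps x, dy, and x.grad alive (each M * N elements in `dtype`) + # plus whatever the provider caches for backward; _test_memory() in utils.py + # then reports peak allocated MB across iterations, not steady-state usage. +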
extra_benchmark_config = input.extra_benchmark_config + M = extra_benchmark_config["M"] + eps = extra_benchmark_config["eps"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (M, N) + + triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device) + llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to(device) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + # utility functions + def y_fwd(): + if provider == "liger": + return triton_rms(x) + if provider == "huggingface": + return llama_rms(x) + + def full(): + y = y_fwd() + y.backward(dy, retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "rms_norm", + "x_name": "H", + "x_label": "hidden size", + "x_values": [2**i for i in range(10, 16)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [{"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_rms_norm, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_rms_norm, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_rope.py b/benchmark/scripts/benchmark_rope.py new file mode 100755 index 0000000000000000000000000000000000000000..ac792881d41965e1ba4fc2be51cd2545159c646e --- /dev/null +++ b/benchmark/scripts/benchmark_rope.py @@ -0,0 +1,223 @@ +import torch +import triton + +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.rope import liger_rotary_pos_emb +from liger_kernel.utils import infer_device +from liger_kernel.utils import transformers_version_dispatch + +device = infer_device() + + +def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + provider = input.kernel_provider + mode = input.kernel_operation_mode + + extra_benchmark_config = input.extra_benchmark_config + num_q_heads = extra_benchmark_config["num_q_heads"] + num_kv_heads = extra_benchmark_config["num_kv_heads"] + dtype = extra_benchmark_config["dtype"] + + # x can be either hidden_size or seq_len + hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x + seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x + + head_dim = hidden_size // num_q_heads + rotary_emb = transformers_version_dispatch( + "4.48.0", + LlamaRotaryEmbedding, + LlamaRotaryEmbedding, + before_kwargs={"dim": head_dim, "device": device}, + after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device}, + ) + q = torch.randn( + (1, seq_len, num_q_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ).transpose(1, 2) + k = torch.randn( + (1, 
seq_len, num_kv_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ).transpose(1, 2) + dq, dk = ( + torch.randn_like(q, device=device, dtype=dtype), + torch.randn_like(k, device=device), + ) + pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) + cos, sin = rotary_emb(k, pos_ids) + + def fwd(): + if provider == "liger": + return liger_rotary_pos_emb(q, k, cos, sin, pos_ids) + elif provider == "huggingface": + return apply_rotary_pos_emb(q, k, cos, sin, pos_ids) + else: + raise ValueError(f"Invalid provider: {provider} for RoPE embedding") + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + grad_to_none=[q, k], + rep=400, + quantiles=QUANTILES, + ) + elif mode == "backward": + q_out, k_out = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True), + grad_to_none=[q, k], + rep=400, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + q_out, k_out = fwd() + torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + grad_to_none=[q, k], + rep=400, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + provider = input.kernel_provider + + extra_benchmark_config = input.extra_benchmark_config + num_q_heads = extra_benchmark_config["num_q_heads"] + num_kv_heads = extra_benchmark_config["num_kv_heads"] + dtype = extra_benchmark_config["dtype"] + + # x can be either hidden_size or seq_len + hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x + seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x + + head_dim = hidden_size // num_q_heads + rotary_emb = transformers_version_dispatch( + "4.48.0", + LlamaRotaryEmbedding, + LlamaRotaryEmbedding, + before_kwargs={"dim": head_dim, "device": device}, + after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device}, + ) + q = torch.randn( + (1, seq_len, num_q_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ).transpose(1, 2) + k = torch.randn( + (1, seq_len, num_kv_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ).transpose(1, 2) + dq, dk = ( + torch.randn_like(q, device=device, dtype=dtype), + torch.randn_like(k, device=device), + ) + pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) + cos, sin = rotary_emb(k, pos_ids) + + def full(): + if provider == "liger": + q_out, k_out = liger_rotary_pos_emb(q, k, cos, sin, pos_ids) + else: + q_out, k_out = apply_rotary_pos_emb(q, k, cos, sin, pos_ids) + torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory( + full, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs_varying_hidden_size = { + "kernel_name": "rope", + "x_name": "H", + "x_label": "hidden size", + "x_values": [32 * (2**i) for i in range(4, 10, 2)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "dtype": torch.bfloat16, + "seq_len": 2048, + "num_q_heads": 32, + 
"num_kv_heads": 8, + } + ], + "overwrite": args.overwrite, + } + run_benchmarks( + bench_test_fn=bench_speed_rope, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs_varying_hidden_size, + ) + run_benchmarks( + bench_test_fn=bench_memory_rope, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs_varying_hidden_size, + ) + + common_configs_varying_seq_len = { + "kernel_name": "rope", + "x_name": "T", + "x_label": "sequence length", + "x_values": [2**i for i in range(10, 15)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "dtype": torch.bfloat16, + "hidden_size": 8192, + "num_q_heads": 32, + "num_kv_heads": 8, + } + ], + "overwrite": args.overwrite, + } + run_benchmarks( + bench_test_fn=bench_speed_rope, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs_varying_seq_len, + ) + run_benchmarks( + bench_test_fn=bench_memory_rope, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs_varying_seq_len, + ) diff --git a/benchmark/scripts/benchmark_simpo_loss.py b/benchmark/scripts/benchmark_simpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..148b8e3e4db0d68cc524f443e3076d7f1afd596c --- /dev/null +++ b/benchmark/scripts/benchmark_simpo_loss.py @@ -0,0 +1,167 @@ +import os +import sys + +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.utils import infer_device + +device = infer_device() + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +############################################################################# +# Test the memory consumption of the linear fused cross entropy loss +############################################################################# + + +def bench_memory_fused_linear_simpo_loss( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO + from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO + + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + + # Instantiate once and retrieve the first output only + torch_lm_head_simpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device) + liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) + torch_fwd = lambda x, target: torch_lm_head_simpo(x, target)[0] + liger_fwd = lambda x, target: liger_lm_head_simpo(x, target)[0] + + _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) + target = torch.randint(V, (B, T), dtype=torch.long, device=device) + + def fwd(): + if provider == "liger": + return liger_fwd(_input, target) + elif provider == "huggingface": + return torch_fwd(_input, target) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +# ############################################################################# +# # Test the speed of 
the fused linear SimPO loss +# ############################################################################# + + +def bench_speed_fused_linear_simpo_loss( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO + from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO + + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + # Instantiate once and retrieve the first output only + torch_lm_head_simpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device) + liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) + torch_fwd = lambda x, target: torch_lm_head_simpo(x, target)[0] + liger_fwd = lambda x, target: liger_lm_head_simpo(x, target)[0] + + _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) + target = torch.randint(V, (B, T), dtype=torch.long, device=device) + + def fwd(): + if provider == "liger": + return liger_fwd(_input, target) + elif provider == "huggingface": + return torch_fwd(_input, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[_input], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "fused_linear_simpo_loss", + "x_name": "B", + "x_label": "B", + "x_values": [2**i for i in range(1, 5)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "T": 1024, + "H": 4096, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_simpo_loss, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_simpo_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_softmax.py b/benchmark/scripts/benchmark_softmax.py new file mode 100755 index 0000000000000000000000000000000000000000..10e994c8c12ab90b0accfcbee81eee9961d84ab2 --- /dev/null +++ b/benchmark/scripts/benchmark_softmax.py @@ -0,0 +1,140 @@ +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.softmax import LigerSoftmax +from liger_kernel.utils import infer_device + +device = infer_device() + + +def bench_speed_softmax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + N = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + extra_benchmark_config = input.extra_benchmark_config + M = extra_benchmark_config["M"] + 
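# input.x sweeps the row width N (the dim the softmax normalizes over), + # while the row count M stays fixed at the value in extra_benchmark_config, + # so the (M, N) problem grows only along the normalized dimension. +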
dtype = extra_benchmark_config["dtype"] + + x_shape = (M, N) + liger_softmax = LigerSoftmax().to(device).to(dtype) + torch_softmax = torch.nn.Softmax(dim=-1).to(device).to(dtype) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + def y_fwd(): + if provider == "liger": + return liger_softmax(x) + if provider == "torch": + return torch_softmax(x) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500) + elif mode == "backward": + y = y_fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(dy, retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[x], + rep=500, + ) + elif mode == "full": + + def full(): + y = y_fwd() + y.backward(dy, retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500) + + if any(val is None for val in (ms_20, ms_50, ms_80)): + raise RuntimeError(f"Benchmark speed result is None: ms_20={ms_20}, ms_50={ms_50}, ms_80={ms_80}") + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_softmax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + N = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + extra_benchmark_config = input.extra_benchmark_config + # Use the same (M, N) problem shape as the speed benchmark; input.x only + # sweeps the softmax dim N, so the row count M comes from the extra config. + M = extra_benchmark_config["M"] + dtype = extra_benchmark_config.get("dtype", torch.float32) + + torch_softmax = torch.nn.Softmax(dim=-1) + liger_softmax = LigerSoftmax().to(device).to(dtype) + + x = torch.randn((M, N), device=device, dtype=dtype, requires_grad=True) + + def fwd(): + if provider == "liger": + return liger_softmax(x) + elif provider == "torch": + return torch_softmax(x) + else: + raise ValueError(f"Invalid provider: {provider} for softmax") + + def full(): + y = fwd() + y.backward(torch.ones_like(y), retain_graph=True) + + if mode == "forward": + mem_50, mem_20, mem_80 = _test_memory(fwd, quantiles=QUANTILES) + elif mode == "backward": + do = torch.ones_like(x) + y = fwd() + mem_50, mem_20, mem_80 = _test_memory(lambda: y.backward(do, retain_graph=True), quantiles=QUANTILES) + else: + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + + if any(val is None for val in (mem_20, mem_50, mem_80)): + raise RuntimeError(f"Benchmark memory result is None: mem_20={mem_20}, mem_50={mem_50}, mem_80={mem_80}") + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = dict( + kernel_name="softmax", + x_name="N", + x_label="hidden size", + x_values=[128, 256, 512, 1024, 2048, 4096], + kernel_providers=["liger", "torch"], + extra_benchmark_configs=[ + {"M": 2048, "dtype": torch.float32}, + {"M": 2048, "dtype": torch.bfloat16}, + ], + ) + + run_benchmarks( + bench_test_fn=bench_speed_softmax, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + overwrite=args.overwrite, + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_softmax, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + overwrite=args.overwrite, + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_sparse_multi_token_attention.py b/benchmark/scripts/benchmark_sparse_multi_token_attention.py new file mode 100755 index 0000000000000000000000000000000000000000..98f47d713920e6842c3be9761672296066d9d644 --- /dev/null +++ 
b/benchmark/scripts/benchmark_sparse_multi_token_attention.py @@ -0,0 +1,254 @@ +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention +from liger_kernel.utils import infer_device + +device = infer_device() + + +class TorchSparseMultiTokenAttention(torch.nn.Module): + def __init__(self, C_in, C_out, K, groups, bias, dtype, device): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(C_out, C_in // groups, K, K, dtype=dtype, device=device)) + self.bias = torch.nn.Parameter(torch.empty(C_out, dtype=dtype, device=device)) if bias else None + self.K = K + self.groups = groups + self.dtype = dtype + self.compute_dtype = torch.float32 + + def forward(self, scores): + B, C_in, L, _ = scores.shape + mask = torch.tril(torch.ones(L, L, dtype=torch.bool, device=scores.device)).view(1, 1, L, L) + inf = torch.tensor(-1e9, device=scores.device, dtype=self.compute_dtype) + zero = torch.tensor(0.0, device=scores.device, dtype=self.compute_dtype) + + s_compute = scores.to(self.compute_dtype) + s_inf = s_compute.masked_fill(~mask, inf) + + dim = -1 + z = s_inf + + z_sorted, _ = torch.sort(z, dim=dim, descending=True) + + cum_sum = torch.cumsum(z_sorted, dim=dim) + + k_indices = torch.arange(1, L + 1, device=z.device, dtype=z.dtype).view(1, 1, 1, L) + + is_positive = z_sorted > -1e8 + condition = (1 + k_indices * z_sorted > cum_sum) & is_positive + k_sparsemax = torch.sum(condition, dim=dim, keepdim=True) + + k_sparsemax_safe = torch.max(k_sparsemax, torch.ones_like(k_sparsemax)) + + cum_sum_k = torch.gather(cum_sum, dim=dim, index=k_sparsemax_safe.long() - 1) + + tau = (cum_sum_k - 1) / k_sparsemax_safe.to(z.dtype) + tau = torch.where(k_sparsemax == 0, torch.full_like(tau, float("inf")), tau) + + probs = torch.clamp(z - tau, min=0) + + weight_compute = self.weight.to(self.compute_dtype) + bias_compute = self.bias.to(self.compute_dtype) if self.bias is not None else None + + out_c = torch.nn.functional.conv2d( + probs, weight_compute, bias_compute, stride=1, padding=self.K // 2, groups=self.groups + ) + return out_c.masked_fill(~mask, zero).to(scores.dtype) + + +def bench_speed_sparse_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + L = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + + extra_benchmark_config = input.extra_benchmark_config + B = extra_benchmark_config["B"] + C_in = extra_benchmark_config["C_in"] + C_out = extra_benchmark_config["C_out"] + K = extra_benchmark_config["K"] + groups = extra_benchmark_config["groups"] + bias = extra_benchmark_config["bias"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (B, C_in, L, L) + + liger_attn = ( + LigerMultiTokenAttention( + in_channels=C_in, + out_channels=C_out, + kernel_size=K, + stride=1, + padding=K // 2, + dilation=1, + groups=groups, + bias=bias, + sparse=True, + ) + .to(device) + .to(dtype) + ) + + torch_attn = TorchSparseMultiTokenAttention( + C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device + ) + + with torch.no_grad(): + torch.nn.init.kaiming_uniform_(liger_attn.weight, a=5**0.5) + if bias: + torch.nn.init.zeros_(liger_attn.bias) + torch_attn.weight.copy_(liger_attn.weight) + if bias: + torch_attn.bias.copy_(liger_attn.bias) + + x 
= torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + def fwd(): + if provider == "liger": + return liger_attn(x) + elif provider == "torch": + return torch_attn(x) + + print(f"Starting Warmup for input size: {x_shape}") + _ = fwd() + if mode in ("backward", "full"): + y = _ + y.backward(dy, retain_graph=True) + print("Done Warmup") + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, grad_to_none=[x], rep=100, quantiles=QUANTILES) + elif mode == "backward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(dy, retain_graph=True), + grad_to_none=[x], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward(dy, retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_sparse_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + L = input.x + provider = input.kernel_provider + + extra_benchmark_config = input.extra_benchmark_config + B = extra_benchmark_config["B"] + C_in = extra_benchmark_config["C_in"] + C_out = extra_benchmark_config["C_out"] + K = extra_benchmark_config["K"] + groups = extra_benchmark_config["groups"] + bias = extra_benchmark_config["bias"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (B, C_in, L, L) + + liger_attn = ( + LigerMultiTokenAttention( + in_channels=C_in, + out_channels=C_out, + kernel_size=K, + stride=1, + padding=K // 2, + dilation=1, + groups=groups, + bias=bias, + sparse=True, + ) + .to(device) + .to(dtype) + ) + + torch_attn = TorchSparseMultiTokenAttention( + C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device + ) + + with torch.no_grad(): + torch.nn.init.kaiming_uniform_(liger_attn.weight, a=5**0.5) + if bias: + torch.nn.init.zeros_(liger_attn.bias) + torch_attn.weight.copy_(liger_attn.weight) + if bias: + torch_attn.bias.copy_(liger_attn.bias) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + def fwd(): + if provider == "liger": + return liger_attn(x) + elif provider == "torch": + return torch_attn(x) + + def full(): + y = fwd() + y.backward(dy, retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "sparse_multi_token_attention", + "x_name": "L", + "x_label": "sequence length", + "x_values": [2**i for i in range(5, 10)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "B": 2, + "C_in": 4, + "C_out": 4, + "K": 3, + "groups": 1, + "bias": True, + "dtype": torch.float32, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_sparse_multi_token_attention, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_sparse_multi_token_attention, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_sparsemax.py b/benchmark/scripts/benchmark_sparsemax.py new file mode 100755 index 
0000000000000000000000000000000000000000..919f4c66defbe1d2388aa17c912fec09058316d4 --- /dev/null +++ b/benchmark/scripts/benchmark_sparsemax.py @@ -0,0 +1,172 @@ +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.sparsemax import LigerSparsemax +from liger_kernel.utils import infer_device + +device = infer_device() + + +def torch_sparsemax(input_tensor: torch.Tensor, dim: int = -1) -> torch.Tensor: + input_dims = input_tensor.dim() + if dim < 0: + dim = input_dims + dim + input_sorted, _ = torch.sort(input_tensor, dim=dim, descending=True) + cumsum_input = torch.cumsum(input_sorted, dim=dim) + input_size = input_tensor.size(dim) + range_tensor = torch.arange(1, input_size + 1, device=input_tensor.device, dtype=input_tensor.dtype) + shape = [1] * input_dims + shape[dim] = input_size + range_tensor = range_tensor.view(shape) + k_bound = 1 + range_tensor * input_sorted + support = k_bound > cumsum_input + k = support.sum(dim=dim, keepdim=True).clamp(min=1) + support_sum = (input_sorted * support).sum(dim=dim, keepdim=True) + tau = (support_sum - 1) / k + return torch.clamp(input_tensor - tau, min=0) + + +class TorchSparsemax(torch.nn.Module): + def __init__(self, dim: int = -1): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch_sparsemax(x, dim=self.dim) + + +def bench_speed_sparsemax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + V = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + + extra_benchmark_config = input.extra_benchmark_config + B = extra_benchmark_config["B"] + T = extra_benchmark_config["T"] + dim = extra_benchmark_config["dim"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (B * T, V) + + torch_sparsemax_module = TorchSparsemax(dim=dim).to(device) + liger_sparsemax_module = LigerSparsemax(dim=dim).to(device) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + # utility functions + def y_fwd(): + if provider == "liger": + return liger_sparsemax_module(x) + elif provider == "torch": + return torch_sparsemax_module(x) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + y_fwd, + grad_to_none=[x], + rep=500, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = y_fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(dy, retain_graph=True), + grad_to_none=[x], + rep=500, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = y_fwd() + y.backward(dy, retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + grad_to_none=[x], + rep=500, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_sparsemax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + V = input.x + provider = input.kernel_provider + + extra_benchmark_config = input.extra_benchmark_config + B = extra_benchmark_config["B"] + T = extra_benchmark_config["T"] + dim = extra_benchmark_config["dim"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (B * T, V) + + torch_sparsemax_module = TorchSparsemax(dim=dim).to(device) + liger_sparsemax_module = LigerSparsemax(dim=dim).to(device) + + x = torch.randn(x_shape, dtype=dtype, 
device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + # utility functions + def y_fwd(): + if provider == "liger": + return liger_sparsemax_module(x) + elif provider == "torch": + return torch_sparsemax_module(x) + + def full(): + y = y_fwd() + y.backward(dy, retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "sparsemax", + "x_name": "V", + "x_label": "feature size", + "x_values": [2**i for i in range(10, 16)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [{"B": 4, "T": 512, "dim": -1, "dtype": torch.float32}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_sparsemax, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_sparsemax, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_swiglu.py b/benchmark/scripts/benchmark_swiglu.py new file mode 100755 index 0000000000000000000000000000000000000000..8d46572fdb1ccbd0e14daa4be1e4da3037575ca8 --- /dev/null +++ b/benchmark/scripts/benchmark_swiglu.py @@ -0,0 +1,115 @@ +import math + +import torch + +from benchmark_model_configs import compute_seq_len_sweep_config +from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import get_benchmark_model_config +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaMLP +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import parse_benchmark_script_args +from utils import run_benchmarks +from utils import run_memory_benchmark +from utils import run_speed_benchmark + +from liger_kernel.transformers.swiglu import LigerSwiGLUMLP +from liger_kernel.utils import infer_device + +device = infer_device() + + +def _setup_swiglu(input: SingleBenchmarkRunInput): + """Create input tensor and SwiGLU layer from benchmark config.""" + cfg = input.extra_benchmark_config + llama_config = LlamaConfig( + hidden_size=cfg["hidden_size"], + intermediate_size=cfg["intermediate_size"], + hidden_act=cfg["hidden_act"], + ) + x = torch.randn( + cfg["bsz"], + input.x, + cfg["hidden_size"], + device=device, + dtype=cfg["dtype"], + requires_grad=True, + ) + if input.kernel_provider == "liger": + layer = LigerSwiGLUMLP(config=llama_config).to(device).to(cfg["dtype"]) + elif input.kernel_provider == "huggingface": + layer = LlamaMLP(config=llama_config).to(device).to(cfg["dtype"]) + else: + raise ValueError(f"Invalid provider: {input.kernel_provider} for SwiGLU") + return x, layer + + +def bench_speed_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _setup_swiglu(input) + return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) + + +def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _setup_swiglu(input) + return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + model = get_benchmark_model_config(args.model) + probe_seq_len = 1024 + + def _probe(): + probe_input = 
SingleBenchmarkRunInput( + x=probe_seq_len, + kernel_provider="huggingface", + extra_benchmark_config={ + "bsz": 1, + "hidden_size": model.hidden_size, + "intermediate_size": model.intermediate_size, + "hidden_act": "silu", + "dtype": model.dtype, + }, + ) + x, layer = _setup_swiglu(probe_input) + return layer(x) + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_seq_len + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "swiglu", + "x_name": "T", + "x_label": "sequence length", + "x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "bsz": config.batch_size, + "hidden_size": model.hidden_size, + "intermediate_size": model.intermediate_size, + "hidden_act": "silu", + "dtype": model.dtype, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_swiglu, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_swiglu, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_tiled_mlp.py b/benchmark/scripts/benchmark_tiled_mlp.py new file mode 100755 index 0000000000000000000000000000000000000000..1eaf21dac9f5c14aa46332ab832109ca37809957 --- /dev/null +++ b/benchmark/scripts/benchmark_tiled_mlp.py @@ -0,0 +1,397 @@ +import math + +import torch +import torch.nn as nn +import triton + +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaMLP +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.geglu import LigerGEGLUMLP +from liger_kernel.transformers.swiglu import LigerSwiGLUMLP +from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP +from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP +from liger_kernel.utils import infer_device + +device = infer_device() + + +# DeepSpeed TiledMLP implementation +# Based on: https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838 +class DeepSpeedTiledMLP(torch.autograd.Function): + """ + DeepSpeed's TiledMLP implementation for fair comparison. + This is the actual DeepSpeed algorithm that performs tiled MLP computation + to massively reduce memory usage with very long sequence lengths. + + This module re-computes forward in the backward, so forward occurs twice per iteration. 
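+ + A sketch of the flow (as implemented below): + forward: chunk x into `shards` slices along the sequence dim, run fn on + each slice under no_grad, and cat the outputs back together. + backward: re-run fn per slice with grad enabled and backprop the matching + slice of the incoming gradient into a preallocated x_grad buffer.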
+ """ + + @staticmethod + def forward(ctx, fn, self, x, shards, compute_params) -> torch.Tensor: + ctx.fn = fn + ctx.self = self + ctx.shards = shards + ctx.compute_params = [p for p in compute_params if p.requires_grad] if compute_params else [] + ctx.save_for_backward(x) + + # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts) + x_shards = list(torch.chunk(x, chunks=shards, dim=-2)) + with torch.no_grad(): + output_shards = [fn(self, x_shard) for x_shard in x_shards] + output_unsharded = torch.cat(output_shards, dim=-2) + + return output_unsharded + + @staticmethod + def backward(ctx, *grads): + fn = ctx.fn + (x,) = ctx.saved_tensors + self = ctx.self + shards = ctx.shards + compute_params = ctx.compute_params + + x_requires_grad = x.requires_grad + x = x.detach() + # detach() unsets x.requires_grad, so restore it + x.requires_grad_(x_requires_grad) + + # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts) + hidden_size = x.shape[-1] + x_shape_orig = x.shape + + # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1 + x = x.view(-1, hidden_size) + incoming_grad = grads[0].view(-1, hidden_size) + x_grad = torch.zeros_like(x) + + x_shards = list(torch.chunk(x, chunks=shards, dim=0)) + + for i, x_shard in enumerate(x_shards): + # Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run + # XXX: DDP, FSDP will need something similar to make it work + if compute_params: + if i + 1 < shards: + for param in compute_params: + if hasattr(param, "ds_grad_is_ready"): + param.ds_grad_is_ready = False + else: + # last shard, can add the grad + for param in compute_params: + if hasattr(param, "ds_grad_is_ready"): + param.ds_grad_is_ready = True + + x_shard.requires_grad_(x_requires_grad) + + # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step + shard_step = x_shards[i].shape[0] + shard_offset = i * x_shards[0].shape[0] + + x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard) + incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard) + with torch.enable_grad(): + output = fn(self, x_shard) + torch.autograd.backward(output, incoming_grad_shard) + + # unflatten + x_grad = x_grad.view(x_shape_orig) + + return (None, None, x_grad, None, None) + + +# DeepSpeed TiledMLP wrapper to match our interface +class DeepSpeedTiledMLPWrapper(nn.Module): + """ + Wrapper for DeepSpeed's TiledMLP to match the interface used in benchmarks. + Uses the DeepSpeed TiledMLP algorithm for memory-efficient MLP computation. 
+ """ + + def __init__(self, config, num_shards=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.num_shards = num_shards + + self.mlp = LlamaMLP(config=config) + + def forward(self, x): + # Calculate num_shards if not provided + num_shards = self.num_shards + if num_shards is None: + hidden_size = x.shape[-1] + seqlen = x.shape[-2] + num_shards = math.ceil(seqlen / hidden_size) + num_shards = max(1, num_shards) + + # Collect compute parameters for DeepSpeed ZeRO compatibility + compute_params = [ + self.mlp.down_proj.weight, + self.mlp.gate_proj.weight, + self.mlp.up_proj.weight, + ] + + # Define the MLP forward function for DeepSpeed TiledMLP + def mlp_forward(mlp_module, x_input): + return mlp_module.down_proj(mlp_module.act_fn(mlp_module.gate_proj(x_input)) * mlp_module.up_proj(x_input)) + + # Use DeepSpeed's TiledMLP implementation + return DeepSpeedTiledMLP.apply( + mlp_forward, + self.mlp, + x, + num_shards, + compute_params, + ) + + +def bench_speed_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + seq_len = input.x + bsz = input.extra_benchmark_config["bsz"] + hidden_size = input.extra_benchmark_config["hidden_size"] + intermediate_size = input.extra_benchmark_config["intermediate_size"] + hidden_act = input.extra_benchmark_config["hidden_act"] + dtype = input.extra_benchmark_config["dtype"] + num_shards = input.extra_benchmark_config.get("num_shards", None) + activation_type = input.extra_benchmark_config["activation_type"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + llama_config = LlamaConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + ) + + x_shape = (bsz, seq_len, hidden_size) + + # initialize input + x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) + + if activation_type == "geglu": + if provider == "huggingface": + layer = LlamaMLP(config=llama_config).to(device).to(dtype) + elif provider == "liger": + layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype) + elif provider == "liger_tiled": + layer = LigerTiledGEGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype) + elif provider == "deepspeed_tiled": + layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype) + else: + raise ValueError(f"Invalid provider: {provider} for GEGLU") + elif activation_type == "swiglu": + if provider == "huggingface": + layer = LlamaMLP(config=llama_config).to(device).to(dtype) + elif provider == "liger": + layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype) + elif provider == "liger_tiled": + layer = LigerTiledSwiGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype) + elif provider == "deepspeed_tiled": + layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype) + else: + raise ValueError(f"Invalid provider: {provider} for SwiGLU") + else: + raise ValueError(f"Invalid activation_type: {activation_type}") + + def fwd(): + return layer(x) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + grad_to_none=[x], + rep=10, + quantiles=QUANTILES, + ) + elif mode == "backward": + do = torch.randn_like(x) + y = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(do, retain_graph=True), + grad_to_none=[x], + rep=10, + quantiles=QUANTILES, + ) + else: + + def full(): + y = fwd() + 
y.backward(torch.randn_like(y), retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + grad_to_none=[x], + rep=10, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + seq_len = input.x + bsz = input.extra_benchmark_config["bsz"] + hidden_size = input.extra_benchmark_config["hidden_size"] + intermediate_size = input.extra_benchmark_config["intermediate_size"] + hidden_act = input.extra_benchmark_config["hidden_act"] + dtype = input.extra_benchmark_config["dtype"] + num_shards = input.extra_benchmark_config.get("num_shards", None) + activation_type = input.extra_benchmark_config["activation_type"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + llama_config = LlamaConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + ) + + x_shape = (bsz, seq_len, hidden_size) + # initialize input + x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) + + if activation_type == "geglu": + if provider == "huggingface": + layer = LlamaMLP(config=llama_config).to(device).to(dtype) + elif provider == "liger": + layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype) + elif provider == "liger_tiled": + layer = LigerTiledGEGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype) + elif provider == "deepspeed_tiled": + layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype) + else: + raise ValueError(f"Invalid provider: {provider} for GEGLU") + elif activation_type == "swiglu": + if provider == "huggingface": + layer = LlamaMLP(config=llama_config).to(device).to(dtype) + elif provider == "liger": + layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype) + elif provider == "liger_tiled": + layer = LigerTiledSwiGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype) + elif provider == "deepspeed_tiled": + layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype) + else: + raise ValueError(f"Invalid provider: {provider} for SwiGLU") + else: + raise ValueError(f"Invalid activation_type: {activation_type}") + + def fwd(): + return layer(x) + + def full(): + y = fwd() + y.backward(torch.randn_like(y), retain_graph=True) + + if mode == "forward": + mem_50, mem_20, mem_80 = _test_memory( + fwd, + quantiles=QUANTILES, + ) + elif mode == "backward": + do = torch.randn_like(x) + y = fwd() + mem_50, mem_20, mem_80 = _test_memory( + lambda: y.backward(do, retain_graph=True), + quantiles=QUANTILES, + ) + else: + mem_50, mem_20, mem_80 = _test_memory( + full, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + # Benchmark GEGLU variants + kernel_providers_geglu = ["huggingface", "liger", "liger_tiled", "deepspeed_tiled"] + + common_configs_geglu = { + "kernel_name": "tiled_geglu", + "x_name": "T", + "x_label": "sequence length", + "x_values": [2**i for i in range(10, 15)], # 1024 to 16384 + "kernel_providers": kernel_providers_geglu, + "extra_benchmark_configs": [ + { + "bsz": 2, + "hidden_size": 2048, + "intermediate_size": 4096, + "hidden_act": "gelu_pytorch_tanh", + "activation_type": "geglu", + "num_shards": 4, + "dtype": torch.bfloat16, + } + ], + "overwrite": args.overwrite, + } + + 
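# Back-of-envelope sizing (a rough estimate, not measured): at the largest + # sweep point (bsz=2, T=2**14, intermediate_size=4096, bf16) the plain MLP's + # gate_proj/up_proj activations alone occupy about + # 2 * 2 * 16384 * 4096 * 2 bytes ~= 512 MB, which is the headroom the tiled + # providers (num_shards=4) trade for recomputing the forward in backward. +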
run_benchmarks( + bench_test_fn=bench_speed_tiled_mlp, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs_geglu, + ) + run_benchmarks( + bench_test_fn=bench_memory_tiled_mlp, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs_geglu, + ) + + # Benchmark SwiGLU variants + kernel_providers_swiglu = ["huggingface", "liger", "liger_tiled", "deepspeed_tiled"] + + common_configs_swiglu = { + "kernel_name": "tiled_swiglu", + "x_name": "T", + "x_label": "sequence length", + "x_values": [2**i for i in range(10, 15)], # 1024 to 16384 + "kernel_providers": kernel_providers_swiglu, + "extra_benchmark_configs": [ + { + "bsz": 2, + "hidden_size": 2048, + "intermediate_size": 4096, + "hidden_act": "silu", + "activation_type": "swiglu", + "num_shards": 4, + "dtype": torch.bfloat16, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_tiled_mlp, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs_swiglu, + ) + run_benchmarks( + bench_test_fn=bench_memory_tiled_mlp, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs_swiglu, + ) diff --git a/benchmark/scripts/benchmark_tvd.py b/benchmark/scripts/benchmark_tvd.py new file mode 100755 index 0000000000000000000000000000000000000000..ef76380a2664a658a2e19c1815fa03cd8a47a5d9 --- /dev/null +++ b/benchmark/scripts/benchmark_tvd.py @@ -0,0 +1,145 @@ +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.tvd import LigerTVDLoss +from liger_kernel.utils import get_total_gpu_memory +from liger_kernel.utils import infer_device + +device = infer_device() + + +class TorchTVDLoss(torch.nn.Module): + def __init__(self, reduction="batchmean"): + super(TorchTVDLoss, self).__init__() + self.reduction = reduction + + def forward(self, p, q): + tvd = torch.abs(p - q) / 2.0 + if self.reduction == "mean": + return torch.sum(tvd) / (p.size(0) * p.size(1)) + elif self.reduction == "sum": + return torch.sum(tvd) + elif self.reduction == "none": + return tvd + elif self.reduction == "batchmean": + return torch.sum(tvd) / p.size(0) + else: + raise ValueError("Invalid reduction type.") + + +S, E = 12, 18 + + +def bench_speed_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + reduction = "batchmean" + V = input.x + B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] + torch_tvd = TorchTVDLoss(reduction=reduction) + liger_tvd = LigerTVDLoss(reduction=reduction) + + _input = torch.randn(B * T, V, requires_grad=True, device=device).softmax(dim=-1) + target = torch.randn(B * T, V, device=device).softmax(dim=-1) + + def fwd(): + if input.kernel_provider == "liger": + return liger_tvd(_input, target) + else: + return torch_tvd(_input, target) + + if input.kernel_operation_mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) + elif input.kernel_operation_mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[_input], + rep=100, + ) + elif 
input.kernel_operation_mode == "full": + + def full(): + y = fwd() + y.backward(retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + reduction = "batchmean" + torch_tvd = TorchTVDLoss(reduction=reduction) + liger_tvd = LigerTVDLoss(reduction=reduction) + + V = input.x + B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] + + _input = torch.randn(B * T, V, requires_grad=True, device=device).softmax(dim=-1) + target = torch.randn(B * T, V, device=device).softmax(dim=-1) + + def fwd(): + if input.kernel_provider == "liger": + return liger_tvd(_input, target) + else: + return torch_tvd(_input, target) + + def full(): + y = fwd() + y.backward(retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + gpu_memory_gbs = get_total_gpu_memory() + # We know that the full test will require 66 GB for vocab size 2^17 + if gpu_memory_gbs >= 66: + x_max = 17 + elif gpu_memory_gbs >= 32: + x_max = 16 + else: + x_max = 15 + common_args = { + "kernel_name": "tvd", + "x_name": "V", + "x_label": "vocab size", + "x_values": [2**i for i in range(12, x_max + 1)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [{"B": 8, "T": 2048}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_memory_tvd, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_args, + ) + + run_benchmarks( + bench_test_fn=bench_speed_tvd, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_args, + ) diff --git a/benchmark/scripts/utils.py b/benchmark/scripts/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..e6b4fc9e85d03673729fb0edf8d66e315a2d0b16 --- /dev/null +++ b/benchmark/scripts/utils.py @@ -0,0 +1,439 @@ +import argparse +import csv +import json +import os +import time + +from collections import OrderedDict +from dataclasses import asdict +from dataclasses import dataclass +from importlib.metadata import version +from itertools import zip_longest +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +import torch + +from liger_kernel.utils import infer_device + +device = infer_device() + +LIGER_KERNEL_VERSION = version("liger-kernel") + +QUANTILES = [0.5, 0.2, 0.8] + + +@dataclass +class SingleBenchmarkRunInput: + x: Union[int, float] + kernel_provider: str + kernel_operation_mode: Optional[str] = "" + extra_benchmark_config: Optional[Dict[str, Any]] = None + + +@dataclass +class SingleBenchmarkRunOutput: + # 20th percentile + y_20: float + # 50th percentile (median) + y_50: float + # 80th percentile + y_80: float + + +@dataclass +class BenchmarkData: + """ + BenchmarkData is a dataclass to store the benchmark data for a completed benchmark + run on all x-values for a given kernel/kernel operation mode/metric/extra_benchmark_config + """ + + kernel_name: str + kernel_provider: str + metric_name: str + metric_unit: str + gpu_name: str + x_name: str + x_label: str + x_values: List[float] + y_values_50: List[float] + y_values_20: 
List[float] + y_values_80: List[float] + timestamp: str + kernel_operation_mode: Optional[str] = None + extra_benchmark_config_str: Optional[str] = None + liger_version: str = LIGER_KERNEL_VERSION + + +@dataclass +class BenchmarkDataCSVRow: + # The ordering of field names here will be the order of columns in the CSV + kernel_name: str + kernel_provider: str + kernel_operation_mode: Union[str, None] + metric_name: str + metric_unit: str + x_name: str + x_label: str + x_value: float + y_value_50: float + y_value_20: float + y_value_80: float + extra_benchmark_config_str: Union[str, None] + gpu_name: str + timestamp: str + liger_version: str + + +def _test_memory( + func: Callable, + _iter: int = 10, + quantiles: Optional[List[float]] = None, + return_mode="mean", +) -> float: + assert return_mode in ["min", "max", "mean", "median"] + total_mem = [] + + for _ in range(_iter): + getattr(torch, device).memory.reset_peak_memory_stats() + func() + # Convert to MB + mem = getattr(torch, device).max_memory_allocated() / 2**20 + total_mem.append(mem) + + total_mem = torch.tensor(total_mem, dtype=torch.float) + if quantiles is not None: + quantiles_data = torch.quantile(total_mem, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(quantiles_data) == 1: + quantiles_data = quantiles_data[0] + return quantiles_data + return getattr(torch, return_mode)(total_mem).item() + + +def run_speed_benchmark( + fwd_fn: Callable, + mode: str, + input_tensors: List[torch.Tensor], + rep: int = 10, +) -> "SingleBenchmarkRunOutput": + """Measure execution speed for forward, backward, or full (fwd+bwd). + + Covers the common case where the forward function returns a single tensor + and backward uses a random gradient of the same shape. For kernels with + scalar output (losses) or multiple outputs (e.g. RoPE), write custom + measurement logic instead. + """ + import triton + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd_fn, + grad_to_none=input_tensors, + rep=rep, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd_fn() + do = torch.randn_like(y) + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(do, retain_graph=True), + grad_to_none=input_tensors, + rep=rep, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd_fn() + y.backward(torch.randn_like(y), retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + grad_to_none=input_tensors, + rep=rep, + quantiles=QUANTILES, + ) + else: + raise ValueError(f"Unsupported mode: {mode}. Use 'forward', 'backward', or 'full'.") + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def run_memory_benchmark( + fwd_fn: Callable, + mode: str, +) -> "SingleBenchmarkRunOutput": + """Measure peak memory for forward, backward, or full (fwd+bwd). + + Same caveats as :func:`run_speed_benchmark` regarding output shape. + """ + if mode == "forward": + mem_50, mem_20, mem_80 = _test_memory(fwd_fn, quantiles=QUANTILES) + elif mode == "backward": + y = fwd_fn() + do = torch.randn_like(y) + mem_50, mem_20, mem_80 = _test_memory( + lambda: y.backward(do, retain_graph=True), + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd_fn() + y.backward(torch.randn_like(y), retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + else: + raise ValueError(f"Unsupported mode: {mode}. 
Use 'forward', 'backward', or 'full'.") + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +def get_current_file_directory() -> str: + """ + Returns the directory path of the current Python file. + """ + # Get the absolute path of the current file + current_file_path = os.path.abspath(__file__) + + # Get the directory path of the current file + return os.path.dirname(current_file_path) + + +def sleep(seconds): + def decorator(function): + def wrapper(*args, **kwargs): + time.sleep(seconds) + return function(*args, **kwargs) + + return wrapper + + return decorator + + +def _print_benchmarking_banner(metric_name: str, kernel_name: str): + print("**************************************") + print(f" BENCHMARKING {metric_name.upper()} for {kernel_name.upper()}") + print("**************************************") + + +def get_formatted_time(): + return time.strftime("%Y-%m-%d %H:%M:%S") + + +def get_gpu_name(): + """ + Returns the current GPU name, formatted to serve as a directory name + """ + torch_device = getattr(torch, device) + if torch_device.is_available(): + gpu_name = torch_device.get_device_name(torch_device.current_device()) + return gpu_name + else: + raise Exception("Benchmarks can only be run on GPU.") + + +def update_benchmark_data_csv( + benchmark_data_list: List[BenchmarkData], + filename: str = "all_benchmark_data.csv", + overwrite: bool = True, +): + """ + Update the CSV file with the new benchmark data. If the file does not exist, create it. + If an entry already exists for the benchmark, then overwrite it if `overwrite` is True. + """ + + def create_unique_key(row): + # This unique key is used to determine if a benchmark run already exists in the CSV + # If the key is the same, then the benchmark run already exists and will optionally + # be overwritten. Otherwise, it is considered a new benchmark run and appended. 
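+        # For example, a row from the TVD speed benchmark would produce a key like
+        # (gpu name shown as a placeholder, values are illustrative):
+        #     ("tvd", "liger", "forward", "speed", "V", "4096", '{"B": 8, "T": 2048}', "<gpu_name>")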
+        return (
+            row["kernel_name"],
+            row["kernel_provider"],
+            row["kernel_operation_mode"] if row["kernel_operation_mode"] else "",
+            row["metric_name"],
+            row["x_name"],
+            str(row["x_value"]),
+            (row["extra_benchmark_config_str"] if row["extra_benchmark_config_str"] else ""),
+            row["gpu_name"],
+        )
+
+    fieldnames = BenchmarkDataCSVRow.__annotations__.keys()
+
+    # Make filename path relative to current file
+    filename_abs_path = os.path.join(get_current_file_directory(), "../data", filename)
+    file_exists = os.path.isfile(filename_abs_path)
+
+    # Read existing data into a list of dicts
+    existing_data = []
+    if file_exists:
+        with open(filename_abs_path, mode="r") as file:
+            reader = csv.DictReader(file)
+            for row in reader:
+                existing_data.append(row)
+
+    existing_data_dict = OrderedDict((create_unique_key(row), row) for row in existing_data)
+
+    for benchmark_data in benchmark_data_list:
+        benchmark_data_dict = asdict(benchmark_data)
+        x_values = benchmark_data_dict.pop("x_values")
+        y_values_50 = benchmark_data_dict.pop("y_values_50")
+        y_values_20 = benchmark_data_dict.pop("y_values_20")
+        y_values_80 = benchmark_data_dict.pop("y_values_80")
+
+        # Need to convert benchmark_data into multiple rows based on x_values and y_values
+        for x_value, y_value_50, y_value_20, y_value_80 in zip_longest(x_values, y_values_50, y_values_20, y_values_80):
+            if y_value_50 is None:
+                y_value_50 = float("nan")
+            if y_value_20 is None:
+                y_value_20 = float("nan")
+            if y_value_80 is None:
+                y_value_80 = float("nan")
+
+            row = BenchmarkDataCSVRow(
+                x_value=x_value,
+                y_value_50=y_value_50,
+                y_value_20=y_value_20,
+                y_value_80=y_value_80,
+                **benchmark_data_dict,
+            )
+            row_dict = asdict(row)
+
+            row_key = create_unique_key(row_dict)
+
+            if row_key in existing_data_dict:
+                if overwrite:
+                    # If overwriting, update the row
+                    existing_data_dict[row_key] = row_dict
+                else:
+                    # If not overwriting, skip this row
+                    pass
+            else:
+                existing_data_dict[row_key] = row_dict
+    os.makedirs(os.path.dirname(filename_abs_path), exist_ok=True)
+    with open(filename_abs_path, mode="w", newline="") as file:
+        writer = csv.DictWriter(file, fieldnames=fieldnames)
+        writer.writeheader()
+
+        for row in existing_data_dict.values():
+            writer.writerow(row)
+
+
+class CustomEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, torch.dtype):
+            return str(obj)
+        # json.JSONEncoder.default takes only the offending object
+        return super().default(obj)
+
+
+def print_benchmark_data(benchmark_data_list: List[BenchmarkData]) -> None:
+    print("********** Benchmark Data **********")
+    formatted_list = [obj.__dict__ for obj in benchmark_data_list]
+    print(json.dumps(formatted_list, indent=2))
+
+
+def run_benchmarks(
+    bench_test_fn: Callable,
+    kernel_name: str,
+    metric_name: str,
+    metric_unit: str,
+    x_name: str,
+    x_label: str,
+    x_values: List[Union[float, int]],
+    kernel_providers: List[str],
+    kernel_operation_modes: Optional[List[str]] = [None],
+    extra_benchmark_configs: Optional[List[Dict[str, Any]]] = None,
+    overwrite: bool = False,
+):
+    """
+    Run benchmarks given a bench_test_fn that takes in a SingleBenchmarkRunInput as input and
+    saves data to the CSV file.
+
+    Args:
+    - bench_test_fn: The benchmark test function to run. This function should take in a
+        SingleBenchmarkRunInput as input and return a SingleBenchmarkRunOutput.
+    - kernel_name: The name of the kernel being benchmarked (e.g. "swiglu")
+    - metric_name: The name of the metric being benchmarked (e.g. "speed" or "memory")
+    - metric_unit: The unit of the metric being benchmarked (e.g.
"ms" or "MB") + - x_name: The name of the x-axis (e.g. "T" for sequence length) + - x_label: The label of the x-axis (e.g. "sequence length") + - x_values: The list of x-values to run the benchmark on (e.g. [2**i for i in range(10, 14)]) + - kernel_providers: The list of kernel providers to run the benchmark on (e.g. ["liger", "huggingface"]) + - kernel_operation_modes: The list of kernel operation modes to run the benchmark on (e.g. ["full", "backward"]) + - extra_benchmark_configs: The list of extra benchmark configurations to run the benchmark on. + - overwrite: Whether to overwrite the existing benchmark data entry if it already exists. + """ + + assert len(kernel_operation_modes) >= 1 + assert len(kernel_providers) >= 1 + + _print_benchmarking_banner(metric_name=metric_name, kernel_name=kernel_name) + + gpu_name = get_gpu_name() + benchmark_data_list = [] + for extra_benchmark_config in extra_benchmark_configs: + for kernel_operation_mode in kernel_operation_modes: + for kernel_provider in kernel_providers: + y_values_50 = [] + y_values_20 = [] + y_values_80 = [] + + for x in x_values: + single_benchmark_run_input = SingleBenchmarkRunInput( + x=x, + kernel_provider=kernel_provider, + kernel_operation_mode=kernel_operation_mode, + extra_benchmark_config=extra_benchmark_config, + ) + benchmark_result: SingleBenchmarkRunOutput = bench_test_fn(single_benchmark_run_input) + y_values_50.append(benchmark_result.y_50) + y_values_20.append(benchmark_result.y_20) + y_values_80.append(benchmark_result.y_80) + + benchmark_run_data = BenchmarkData( + kernel_name=kernel_name, + kernel_operation_mode=kernel_operation_mode, + kernel_provider=kernel_provider, + metric_name=metric_name, + metric_unit=metric_unit, + gpu_name=gpu_name, + x_name=x_name, + x_label=x_label, + x_values=x_values, + y_values_50=y_values_50, + y_values_20=y_values_20, + y_values_80=y_values_80, + extra_benchmark_config_str=json.dumps(extra_benchmark_config, cls=CustomEncoder), + timestamp=get_formatted_time(), + liger_version=LIGER_KERNEL_VERSION, + ) + + benchmark_data_list.append(benchmark_run_data) + + print_benchmark_data(benchmark_data_list) + + update_benchmark_data_csv(benchmark_data_list=benchmark_data_list, overwrite=overwrite) + + +def parse_benchmark_script_args(): + parser = argparse.ArgumentParser(description="Benchmarking script for Liger-Kernel") + + parser.add_argument( + "--overwrite", + action="store_true", + help="Flag to overwrite existing benchmark data with current run.", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help=( + "Model config name from MODEL_REGISTRY " + "(e.g. llama_2_7b, llama_3_8b). " + "Defaults to llama_3_8b when not specified." 
+ ), + ) + args = parser.parse_args() + return args diff --git a/dev/fmt-requirements.txt b/dev/fmt-requirements.txt new file mode 100755 index 0000000000000000000000000000000000000000..1d8f48692ce6c44eacf156ed5027a1d330741ba4 --- /dev/null +++ b/dev/fmt-requirements.txt @@ -0,0 +1 @@ +ruff>=0.1.6 diff --git a/dev/modal/benchmarks.py b/dev/modal/benchmarks.py new file mode 100755 index 0000000000000000000000000000000000000000..a54fa47d88ba93a9c092f2a062f06a07eb66d992 --- /dev/null +++ b/dev/modal/benchmarks.py @@ -0,0 +1,73 @@ +from pathlib import Path + +import modal + +ROOT_PATH = Path(__file__).parent.parent.parent +REMOTE_ROOT_PATH = "/root/liger-kernel" +PYTHON_VERSION = "3.12" + +image = modal.Image.debian_slim(python_version=PYTHON_VERSION).pip_install("uv") + +app = modal.App("liger_benchmarks", image=image) + +# mount: add local files to the remote container +repo = image.add_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH) + + +@app.function(gpu="H100!", image=repo, timeout=60 * 90) +def liger_benchmarks(): + import os + import subprocess + + subprocess.run( + ["uv pip install -e '.[dev]' --system"], + check=True, + shell=True, + cwd=REMOTE_ROOT_PATH, + ) + subprocess.run(["make run-benchmarks"], check=True, shell=True, cwd=REMOTE_ROOT_PATH) + + file_path = Path(REMOTE_ROOT_PATH) / "benchmark" / "data" / "all_benchmark_data.csv" + print(f"Checking if file exists at: {file_path}") + print(f"File exists: {os.path.exists(file_path)}") + + if not os.path.exists(file_path): + print("Listing directory contents:") + data_dir = file_path.parent + if os.path.exists(data_dir): + print(f"Contents of {data_dir}:") + print(os.listdir(data_dir)) + else: + print(f"Data directory {data_dir} does not exist") + raise FileNotFoundError(f"Benchmark data file not found at {file_path}") + + with open(file_path, "rb") as f: + data = f.read() + print(f"Successfully read {len(data)} bytes of data") + return data + + +@app.local_entrypoint() +def main(): + try: + # Run the benchmarks and get the data + print("Starting benchmark run...") + benchmark_data = liger_benchmarks.remote() + + if not benchmark_data: + raise ValueError("No data received from remote function") + + # Save the data locally + local_data_path = ROOT_PATH / "benchmark" / "data" / "all_benchmark_data.csv" + print(f"Attempting to save data to: {local_data_path}") + + local_data_path.parent.mkdir(parents=True, exist_ok=True) + + with open(local_data_path, "wb") as f: + f.write(benchmark_data) + + print(f"Successfully saved {len(benchmark_data)} bytes to: {local_data_path}") + + except Exception as e: + print(f"Error occurred: {str(e)}") + raise diff --git a/dev/modal/tests.py b/dev/modal/tests.py new file mode 100755 index 0000000000000000000000000000000000000000..07856dcecb1af92fc260e780b65b8b4c10421eeb --- /dev/null +++ b/dev/modal/tests.py @@ -0,0 +1,86 @@ +from pathlib import Path + +import modal + +ROOT_PATH = Path(__file__).parent.parent.parent +REMOTE_ROOT_PATH = "/root/liger-kernel" +PYTHON_VERSION = "3.12" + +OLDEST_SUPPORTED_TRANSFORMERS_V4_VERSION = "4.52.0" + +image = modal.Image.debian_slim(python_version=PYTHON_VERSION).pip_install("uv") + +app = modal.App("liger_tests", image=image) + +# mount: add local files to the remote container +repo = image.add_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH) + + +@app.function(gpu="H100!", image=repo, timeout=90 * 60) +def liger_correctness_tests(): + import subprocess + + subprocess.run( + ["uv pip install -e '.[dev]' --system"], + check=True, + shell=True, + 
cwd=REMOTE_ROOT_PATH,
+    )
+    subprocess.run(["make test"], check=True, shell=True, cwd=REMOTE_ROOT_PATH)
+
+
+@app.function(gpu="H100!", image=repo, timeout=90 * 60)
+def liger_convergence_tests():
+    import subprocess
+
+    subprocess.run(
+        ["uv pip install -e '.[dev]' --system"],
+        check=True,
+        shell=True,
+        cwd=REMOTE_ROOT_PATH,
+    )
+    subprocess.run(["make test-convergence"], check=True, shell=True, cwd=REMOTE_ROOT_PATH)
+
+
+oldest_v4_app = modal.App("liger_oldest_v4_tests", image=image)  # 4.52.0
+
+
+@oldest_v4_app.function(gpu="H100!", image=repo, timeout=90 * 60)
+def liger_oldest_v4_correctness_tests():
+    import subprocess
+
+    subprocess.run(
+        ["uv pip install -e '.[dev]' --system"],
+        check=True,
+        shell=True,
+        cwd=REMOTE_ROOT_PATH,
+    )
+    subprocess.run(
+        [f"uv pip install 'transformers=={OLDEST_SUPPORTED_TRANSFORMERS_V4_VERSION}' --system"],
+        check=True,
+        shell=True,
+        cwd=REMOTE_ROOT_PATH,
+    )
+    subprocess.run(["make test"], check=True, shell=True, cwd=REMOTE_ROOT_PATH)
+
+
+@oldest_v4_app.function(gpu="H100!", image=repo, timeout=90 * 60)
+def liger_oldest_v4_convergence_tests():
+    import subprocess
+
+    subprocess.run(
+        ["uv pip install -e '.[dev]' --system"],
+        check=True,
+        shell=True,
+        cwd=REMOTE_ROOT_PATH,
+    )
+    subprocess.run(
+        [f"uv pip install 'transformers=={OLDEST_SUPPORTED_TRANSFORMERS_V4_VERSION}' --system"],
+        check=True,
+        shell=True,
+        cwd=REMOTE_ROOT_PATH,
+    )
+    subprocess.run(["make test-convergence"], check=True, shell=True, cwd=REMOTE_ROOT_PATH)
+
+
+latest_v4_app = modal.App("liger_latest_v4_tests", image=image)  # 4.57.6
diff --git a/docs/Examples.md b/docs/Examples.md
new file mode 100755
index 0000000000000000000000000000000000000000..41a1fb92fab04c9b4698fbdfa030f4619144f83a
--- /dev/null
+++ b/docs/Examples.md
@@ -0,0 +1,268 @@
+
+!!! Example "HANDS-ON USE CASE EXAMPLES"
+| **Use Case** | **Description** |
+|------------------------------------------------|---------------------------------------------------------------------------------------------------|
+| [**Hugging Face Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface) | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on the Alpaca dataset using 4 A100s with FSDP |
+| [**Lightning Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/lightning) | Increase throughput by 15% and reduce memory usage by 40% with LLaMA3-8B on the MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
+| [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
+| [**Vision-Language Model SFT**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface/run_qwen2_vl.sh) | Finetune Qwen2-VL on image-text data using 4 A100s with FSDP |
+| [**Liger ORPO Trainer**](https://github.com/linkedin/Liger-Kernel/blob/main/examples/alignment/run_orpo.py) | Align Llama 3.2 using the Liger ORPO Trainer with FSDP, with 50% memory reduction |
+
+## HuggingFace Trainer
+
+### How to Run
+
+#### Locally on a GPU machine
+You can run the example locally on a GPU machine. The default hyperparameters and configurations work on a single node with 4xA100 80GB GPUs and FSDP.
+
+!!! Example
+
+```bash
+pip install -r requirements.txt
+sh run_{MODEL}.sh
+```
+
+#### Remotely on Modal
+If you do not have access to a GPU machine, you can run the example on Modal. Modal is a serverless platform that allows you to run your code on a remote GPU machine. You can sign up for a free account at [Modal](https://www.modal.com/).
+
+!!! Example
+
+```bash
+pip install modal
+modal setup # authenticate with Modal
+modal run launch_on_modal.py --script "run_qwen2_vl.sh"
+```
+
+!!! Notes
+
+1. This example uses an optional `use_liger` flag. If true, it applies a one-line monkey patch to use the Liger kernels.
+
+2. The example uses the Llama3 model, which requires a community license agreement and Hugging Face Hub login. If you want to use Llama3 in this example, please make sure you have done the following:
+    * Agree to the [community license agreement](https://huggingface.co/meta-llama/Meta-Llama-3-8B).
+    * Run `huggingface-cli login` and enter your HuggingFace token.
+
+3. The default hyperparameters and configurations work on a single node with 4xA100 80GB GPUs. For running on devices with less GPU RAM, please consider reducing the per-GPU batch size and/or enabling `CPUOffload` in FSDP.
+
+
+### Benchmark Result
+
+### Llama
+
+!!! Info
+>Benchmark conditions:
+>Model = LLaMA 3-8B, Dataset = Alpaca, Max seq len = 512, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 4 A100s.
+
+Throughput improves by around 20%, while GPU memory usage drops by 40%. This allows you to train the model on smaller GPUs, use larger batch sizes, or handle longer sequence lengths without incurring additional costs.
+
+![Throughput](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/examples/huggingface/img/llama_tps.png)
+![GPU Memory Allocated](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/examples/huggingface/img/llama_mem_alloc.png)
+
+### Qwen
+
+!!! Info
+>Benchmark conditions:
+>Model = Qwen2-7B, Dataset = Alpaca, Max seq len = 512, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 4 A100s.
+
+Throughput improves by around 10%, while GPU memory usage drops by 50%.
+
+![Throughput](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/examples/huggingface/img/qwen_tps.png)
+![GPU Memory Allocated](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/examples/huggingface/img/qwen_mem_alloc.png)
+
+
+### Gemma 7B
+
+!!! Info
+>Benchmark conditions:
+>Model = Gemma-7B, Dataset = Alpaca, Max seq len = 512, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 4 A100s.
+
+Throughput improves by around 24%, while GPU memory usage drops by 33%.
+
+![Throughput](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/examples/huggingface/img/gemma_7b_tp.png)
+![GPU Memory Allocated](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/examples/huggingface/img/gemma_7b_mem.png)
+
+## Lightning Trainer
+
+### How to Run
+
+#### Locally on a GPU machine
+You can run the example locally on a GPU machine.
+
+!!! Example
+
+```bash
+pip install -r requirements.txt
+
+# For single L40 48GB GPU
+python training.py --model Qwen/Qwen2-0.5B-Instruct --num_gpu 1 --max_length 1024
+
+# For 8XA100 40GB
+python training.py --model meta-llama/Meta-Llama-3-8B --strategy deepspeed
+```
+
+!!! Notes
+
+1. The example uses the Llama3 model, which requires a community license agreement and Hugging Face Hub login. If you want to use Llama3 in this example, please make sure you have done the following:
+    * Agree to the [community license agreement](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
+    * Run `huggingface-cli login` and enter your HuggingFace token.
+
+2. The default hyperparameters and configuration for Gemma work on a single L40 48GB GPU, and the config for Llama works on a single node with 8xA100 40GB GPUs. For running on devices with less GPU RAM, please consider reducing the per-GPU batch size and/or enabling `CPUOffload` in FSDP.
+
+## Medusa
+
+Medusa is a simple framework that democratizes the acceleration techniques for LLM generation with multiple decoding heads. To learn more, you can check out the [repo](https://github.com/FasterDecoding/Medusa) and the [paper](https://arxiv.org/abs/2401.10774).
+
+The Liger fused CE kernel is highly effective in this scenario, eliminating the need to materialize logits for each head, which usually consumes a large volume of memory due to the extensive vocabulary size (e.g., for LLaMA-3, the vocabulary size is 128k).
+
+The introduction of multiple heads can easily lead to OOM (Out of Memory) issues. However, thanks to the efficient Liger fused CE, which calculates the gradient in place and doesn't materialize the logits, we have observed very effective results. This efficiency opens up more opportunities for multi-token prediction research and development.
+
+
+### How to Run
+
+!!! Example
+
+```bash
+git clone git@github.com:linkedin/Liger-Kernel.git
+cd {PATH_TO_Liger-Kernel}/Liger-Kernel/
+pip install -e .
+cd {PATH_TO_Liger-Kernel}/Liger-Kernel/examples/medusa
+pip install -r requirements.txt
+sh scripts/llama3_8b_medusa.sh
+```
+
+!!! Notes
+
+1. This example uses an optional `use_liger` flag. If true, it applies a monkey patch to use the Liger kernels with the Medusa heads.
+
+2. The example uses the Llama3 model, which requires a community license agreement and Hugging Face Hub login. If you want to use Llama3 in this example, please make sure you have done the following:
+    * Agree to the community license agreement: https://huggingface.co/meta-llama/Meta-Llama-3-8B
+    * Run `huggingface-cli login` and enter your HuggingFace token
+
+3. The default hyperparameters and configurations work on a single node with 8xA100 GPUs. For running on devices with less GPU RAM, please consider reducing the per-GPU batch size and/or enabling `CPUOffload` in FSDP.
+
+4. We are using a smaller sample of ShareGPT data primarily to benchmark performance. The example requires hyperparameter tuning and dataset selection to work effectively, and the dataset should have the same distribution as the LLaMA pretraining data. Contributions to enhance the example code are welcome.
+
+### Benchmark Result
+
+!!! Info
+> 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 6, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
+
+#### Stage 1
+
+Stage 1 refers to Medusa-1, where the backbone model is frozen and only the weights of the LLM heads are updated.
+
+!!! Warning
+```bash
+# Setting this flag to True in llama3_8b_medusa.sh enables Stage 1
+--medusa_only_heads True
+```
+
+#### num_head = 3
+
+![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/examples/medusa/docs/images/Memory_Stage1_num_head_3.png)
+![Throughput](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png)
+
+#### num_head = 5
+
+![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/examples/medusa/docs/images/Memory_Stage1_num_head_5.png)
+![Throughput](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png)
+
+#### Stage 2
+
+!!! Warning
+```bash
+# Setting this flag to False in llama3_8b_medusa.sh enables Stage 2
+--medusa_only_heads False
+```
+
+Stage 2 refers to Medusa-2, where all the model weights are updated, including the backbone model and the LLM heads.
+
+#### num_head = 3
+
+![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/examples/medusa/docs/images/Memory_Stage2_num_head_3.png)
+![Throughput](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png)
+
+#### num_head = 5
+
+![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/examples/medusa/docs/images/Memory_Stage2_num_head_5.png)
+![Throughput](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png)
+
+
+## Vision-Language Model SFT
+
+### How to Run
+
+#### Locally on a GPU Machine
+You can run the example locally on a GPU machine. The default hyperparameters and configurations work on a single node with 4xA100 80GB GPUs.
+
+!!! Example
+```bash
+#!/bin/bash
+
+torchrun --nnodes=1 --nproc-per-node=4 training_multimodal.py \
+    --model_name "Qwen/Qwen2-VL-7B-Instruct" \
+    --bf16 \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 8 \
+    --per_device_eval_batch_size 8 \
+    --eval_strategy "no" \
+    --save_strategy "no" \
+    --learning_rate 6e-6 \
+    --weight_decay 0.05 \
+    --warmup_ratio 0.1 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --include_num_input_tokens_seen \
+    --report_to none \
+    --fsdp "full_shard auto_wrap" \
+    --fsdp_config config/fsdp_config.json \
+    --seed 42 \
+    --use_liger True \
+    --output_dir multimodal_finetuning
+```
+
+## ORPO Trainer
+
+### How to Run
+
+#### Locally on a GPU Machine
+
+You can run the example locally on a GPU machine with FSDP.
+
+!!! Example
+```py
+import torch
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from trl import ORPOConfig  # noqa: F401
+
+from liger_kernel.transformers.trainer import LigerORPOTrainer  # noqa: F401
+
+model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-3.2-1B-Instruct",
+    dtype=torch.bfloat16,
+)
+
+tokenizer = AutoTokenizer.from_pretrained(
+    "meta-llama/Llama-3.2-1B-Instruct",
+    max_length=512,
+    padding="max_length",
+)
+tokenizer.pad_token = tokenizer.eos_token
+
+train_dataset = load_dataset("trl-lib/tldr-preference", split="train")
+
+training_args = ORPOConfig(
+    output_dir="Llama3.2_1B_Instruct",
+    beta=0.1,
+    max_length=128,
+    per_device_train_batch_size=32,
+    max_steps=100,
+    save_strategy="no",
+)
+
+trainer = LigerORPOTrainer(
+    model=model, args=training_args, tokenizer=tokenizer, train_dataset=train_dataset
+)
+
+trainer.train()
+```
\ No newline at end of file
diff --git a/docs/Getting-Started.md b/docs/Getting-Started.md
new file mode 100755
index 0000000000000000000000000000000000000000..3b6af54777479a2ad7078b338cf1da138754be0d
--- /dev/null
+++ b/docs/Getting-Started.md
@@ -0,0 +1,64 @@
+There are a couple of ways to apply Liger kernels, depending on the level of customization required.
+
+### 1. Use AutoLigerKernelForCausalLM
+
+Using the `AutoLigerKernelForCausalLM` is the simplest approach, as you don't have to import a model-specific patching API. If the model type is supported, the modeling code will be automatically patched using the default settings.
+
+!!! Example
+
+    ```python
+    from liger_kernel.transformers import AutoLigerKernelForCausalLM
+
+    # This AutoModel wrapper class automatically monkey-patches the
+    # model with the optimized Liger kernels if the model is supported.
+    model = AutoLigerKernelForCausalLM.from_pretrained("path/to/some/model")
+    ```
+
+### 2. Apply Model-Specific Patching APIs
+
+Using the [patching APIs](https://github.com/linkedin/Liger-Kernel?tab=readme-ov-file#patching), you can swap Hugging Face models with optimized Liger Kernels.
+
+!!! Example
+
+```python
+import transformers
+from liger_kernel.transformers import apply_liger_kernel_to_llama
+
+# 1a. Adding this line automatically monkey-patches the model with the optimized Liger kernels
+apply_liger_kernel_to_llama()
+
+# 1b. You could alternatively specify exactly which kernels are applied
+apply_liger_kernel_to_llama(
+    rope=True,
+    swiglu=True,
+    cross_entropy=True,
+    fused_linear_cross_entropy=False,
+    rms_norm=False
+)
+
+# 2. Instantiate the patched model
+model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama/model")
+```
+
+### 3. Compose Your Own Model
+
+You can take individual [kernels](https://github.com/linkedin/Liger-Kernel?tab=readme-ov-file#model-kernels) to compose your models.
+
+!!!
Example + +```python +from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss +import torch.nn as nn +import torch + +model = nn.Linear(128, 256).cuda() + +# fuses linear + cross entropy layers together and performs chunk-by-chunk computation to reduce memory +loss_fn = LigerFusedLinearCrossEntropyLoss() + +input = torch.randn(4, 128, requires_grad=True, device="cuda") +target = torch.randint(256, (4, ), device="cuda") + +loss = loss_fn(model.weight, input, target) +loss.backward() +``` \ No newline at end of file diff --git a/docs/High-Level-APIs.md b/docs/High-Level-APIs.md new file mode 100755 index 0000000000000000000000000000000000000000..5433e03d38f05d3078361f950a7fd4fbc8bf0598 --- /dev/null +++ b/docs/High-Level-APIs.md @@ -0,0 +1,93 @@ +# High-Level APIs + +## AutoModel + +| **AutoModel Variant** | **API** | +|------------------------|---------| +| AutoModelForCausalLM | `liger_kernel.transformers.AutoLigerKernelForCausalLM` | + +This API extends the implementation of the `AutoModelForCausalLM` within the `transformers` library from Hugging Face. + +::: liger_kernel.transformers.AutoLigerKernelForCausalLM + options: + extra: + show_docstring: true + show_signature: true + show_source: true + +!!! Example "Try it Out" + You can experiment as shown in this example [here](https://github.com/linkedin/Liger-Kernel?tab=readme-ov-file#1-use-autoligerkernelforcausallm). + +--- + +## Patching + +You can also use the Patching APIs to use the kernels for a specific model architecture. + +| **Model** | **API** | **Supported Operations** | +|-------------|--------------------------------------------------------------|-------------------------------------------------------------------------| +| LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| LLaMA 3.2-Vision | `liger_kernel.transformers.apply_liger_kernel_to_mllama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | + +### Function Signatures + +::: liger_kernel.transformers.apply_liger_kernel_to_llama + options: + extra: + show_docstring: true + show_signature: true + +::: liger_kernel.transformers.apply_liger_kernel_to_mllama + options: + extra: + show_docstring: true + show_signature: true + +::: liger_kernel.transformers.apply_liger_kernel_to_mistral + options: + extra: + show_docstring: true + show_signature: true + +::: liger_kernel.transformers.apply_liger_kernel_to_mixtral + options: + extra: + show_docstring: true + show_signature: true 
+ +::: liger_kernel.transformers.apply_liger_kernel_to_gemma + options: + extra: + show_docstring: true + show_signature: true + +::: liger_kernel.transformers.apply_liger_kernel_to_gemma2 + options: + extra: + show_docstring: true + show_signature: true + +::: liger_kernel.transformers.apply_liger_kernel_to_qwen2 + options: + extra: + show_docstring: true + show_signature: true + +::: liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl + options: + extra: + show_docstring: true + show_signature: true + +::: liger_kernel.transformers.apply_liger_kernel_to_phi3 + options: + extra: + show_docstring: true + show_signature: true diff --git a/docs/Low-Level-APIs.md b/docs/Low-Level-APIs.md new file mode 100755 index 0000000000000000000000000000000000000000..03cfcb0081c9455ecafbeabce2277217b29c9bd2 --- /dev/null +++ b/docs/Low-Level-APIs.md @@ -0,0 +1,133 @@ +## Model Kernels + +| **Kernel** | **API** | +|---------------------------------|-------------------------------------------------------------| +| RMSNorm | `liger_kernel.transformers.LigerRMSNorm` | +| LayerNorm | `liger_kernel.transformers.LigerLayerNorm` | +| RoPE | `liger_kernel.transformers.liger_rotary_pos_emb` | +| SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` | +| GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` | +| CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` | +| Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`| +| Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` | +| Softmax | `liger_kernel.transformers.LigerSoftmax` | +| Sparsemax | `liger_kernel.transformers.LigerSparsemax` | +| mHC (Hyper-Connections) | `liger_kernel.transformers.LigerMHC` | + + +### RMS Norm + +RMS Norm simplifies the LayerNorm operation by eliminating mean subtraction, which reduces computational complexity while retaining effectiveness. + +This kernel performs normalization by scaling input vectors to have a unit root mean square (RMS) value. This method allows for a ~7x speed improvement and a ~3x reduction in memory footprint compared to +implementations in PyTorch. + +!!! Example "Try it out" + You can experiment as shown in this example [here](https://colab.research.google.com/drive/1CQYhul7MVG5F0gmqTBbx1O1HgolPgF0M?usp=sharing). + +### RoPE + +RoPE (Rotary Position Embedding) enhances the positional encoding used in transformer models. + +The implementation allows for effective handling of positional information without incurring significant computational overhead. + +!!! Example "Try it out" + You can experiment as shown in this example [here](https://colab.research.google.com/drive/1llnAdo0hc9FpxYRRnjih0l066NCp7Ylu?usp=sharing). + +### SwiGLU + +### GeGLU + +### CrossEntropy + +This kernel is optimized for calculating the loss function used in classification tasks. + +The kernel achieves a ~3x execution speed increase and a ~5x reduction in memory usage for substantial vocabulary sizes compared to implementations in PyTorch. + +!!! Example "Try it out" + You can experiment as shown in this example [here](https://colab.research.google.com/drive/1WgaU_cmaxVzx8PcdKB5P9yHB6_WyGd4T?usp=sharing). + +### Fused Linear CrossEntropy + +This kernel combines linear transformations with cross-entropy loss calculations into a single operation. + +!!! 
Example "Try it out" + You can experiment as shown in this example [here](https://colab.research.google.com/drive/1Z2QtvaIiLm5MWOs7X6ZPS1MN3hcIJFbj?usp=sharing) + +### Multi Token Attention + +The Multi Token Attention kernel implementation provides and optimized fused implementation of multi-token attention over the implemented Pytorch model baseline. This is a new attention mechanism that can operate on multiple Q and K inputs introduced by Meta Research. + +Paper: https://arxiv.org/abs/2504.00927 + +### Softmax + +The Softmax kernel implementation provides an optimized implementation of the softmax operation, which is a fundamental component in neural networks for converting raw scores into probability distributions. + +The implementation shows notable speedups compared to the Softmax PyTorch implementation + + +### Sparsemax + +Sparsemax is a sparse alternative to softmax that produces sparse probability distributions. This kernel implements an efficient version of the sparsemax operation that can be used as a drop-in replacement for softmax in attention mechanisms or classification tasks. + +The implementation achieves significant speed improvements and memory savings compared to standard PyTorch implementations, particularly for large input tensors. + +### mHC (Manifold-Constrained Hyper-Connections) + +mHC implements fused Triton kernels for Manifold-Constrained Hyper-Connections ([arXiv:2512.24880](https://arxiv.org/abs/2512.24880)). It wraps an arbitrary layer `F: [..., C] -> [..., C]` with multiple residual streams, constraining the residual routing matrix `H_res` onto the Birkhoff polytope (doubly-stochastic matrices) via Sinkhorn-Knopp iterations to stabilize training. + +The `LigerMHC` module takes input of shape `[..., HC, C]` where `HC` is the number of residual streams, and performs: + +1. **Coefficients** -- Compute data-dependent routing coefficients (`h_pre`, `h_post`, `h_res`) via fused matmul + RMS normalization + Sinkhorn-Knopp iterations. +2. **Pre-aggregate** -- `x_in = sum_i h_pre[i] * x[i]` (shape: `[..., C]`) +3. **Layer** -- `f_out = layer(x_in)` (shape: `[..., C]`) +4. 
**Post + residual** -- `x_out[o] = sum_i h_res[o,i] * x[i] + h_post[o] * f_out` (shape: `[..., HC, C]`) + +Usage: + +```python +import torch +import torch.nn as nn +from liger_kernel.transformers import LigerMHC + +# Wrap a linear layer with 4 residual streams of dimension 256 +layer = nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16) +mhc = LigerMHC(layer, hc=4, c=256, phi_dtype=torch.bfloat16).cuda() + +# Input: [batch, seq_len, num_streams, channels] in BF16/FP16 +x = torch.randn(2, 128, 4, 256, device="cuda", dtype=torch.bfloat16) +out = mhc(x) # shape: [2, 128, 4, 256] +``` + +Functional APIs are also available: + +- `liger_kernel.transformers.functional.liger_mhc_coeffs` -- Compute routing coefficients +- `liger_kernel.transformers.functional.liger_mhc_pre` -- Pre-aggregation +- `liger_kernel.transformers.functional.liger_mhc_post_res` -- Post-aggregation + residual +- `liger_kernel.transformers.functional.liger_mhc_apply` -- Combined pre + post_res +- `liger_kernel.transformers.functional.liger_mhc_forward` -- Full forward pass (coeffs + pre + layer + post_res) + +## Alignment Kernels + +| **Kernel** | **API** | +|---------------------------------|-------------------------------------------------------------| +| Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` | +| Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` | +| Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` | +| Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` | + +## Distillation Kernels + +| **Kernel** | **API** | +|---------------------------------|-------------------------------------------------------------| +| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` | +| JSD | `liger_kernel.transformers.LigerJSD` | +| Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` | + +## Experimental Kernels + +| **Kernel** | **API** | +|---------------------------------|-------------------------------------------------------------| +| Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` | +| Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` | \ No newline at end of file diff --git a/docs/acknowledgement.md b/docs/acknowledgement.md new file mode 100755 index 0000000000000000000000000000000000000000..9dfdb4f7b8ba3f96bea71af1e7dc56d121723fab --- /dev/null +++ b/docs/acknowledgement.md @@ -0,0 +1,23 @@ + +### Design + +- [@claire_yishan](https://twitter.com/claire_yishan) for the LOGO design +- [Wave Snippets](https://www.wavesnippets.com/) for generating the animated code snippets + +### Code + +We referenced or used the following projects: + + +| # | Project | Description | Location | License | +|---|----------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------| +| 1 | [Unsloth](https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43) | `calculate_settings` to determine block size and warp; We reuse it for Norm and MLP | [Liger Kernel Utils](https://github.com/linkedin/Liger-Kernel/blob/e249eee723978bf8610ff1ea2297d048a2417e20/src/liger_kernel/ops/utils.py#L23) 
| [Apache](https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/LICENSE) |
+| 2 | [Unsloth](https://github.com/unslothai/unsloth/blob/976d11a10d54383aeb7a692c69e01151a20bfd72/unsloth/kernels/rms_layernorm.py#L48) | We modified and added dW calculation on top of the Unsloth implementation | [Liger Kernel RMS Norm](https://github.com/linkedin/Liger-Kernel/blob/e249eee723978bf8610ff1ea2297d048a2417e20/src/liger_kernel/ops/rms_norm.py#L50) | [Apache](https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/LICENSE) |
+| 3 | [Triton tutorial](https://triton-lang.org/main/index.html) | We modified on top of the Triton tutorials | [Liger Kernel RMS Norm](https://github.com/linkedin/Liger-Kernel/blob/e249eee723978bf8610ff1ea2297d048a2417e20/src/liger_kernel/ops/rms_norm.py#L50) | [MIT](https://github.com/triton-lang/triton/blob/main/LICENSE) |
+| 4 | [tiny shakespeare dataset](https://huggingface.co/datasets/karpathy/tiny_shakespeare) | We use the tiny shakespeare dataset to conduct convergence tests on mini models | [Liger Kernel Convergence](https://github.com/linkedin/Liger-Kernel/tree/main/test/convergence) | N/A |
+| 5 | [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy) | We use the idea of gradient-in-forward and chunking | [Liger Kernel Linear Cross Entropy](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_linear_cross_entropy.py) | [MIT](https://github.com/mgmalek/efficient_cross_entropy/blob/main/LICENSE) |
+| 6 | [Flash attn](https://github.com/Dao-AILab/flash-attention) | We take many optimization ideas from the work, such as tiling and recomputation | | [BSD](https://github.com/Dao-AILab/flash-attention/blob/main/LICENSE) |
+| 7 | [AutoAWQ](https://github.com/casper-hansen/AutoAWQ) | We reference the design of the automodel | [Liger Kernel Auto Model](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/auto_model.py) | [MIT](https://github.com/casper-hansen/AutoAWQ/blob/main/LICENSE) |
+| 8 | [llm.c](https://github.com/karpathy/llm.c) | We reference the design of the end-to-end testing | [Liger Kernel Convergence Tests](https://github.com/linkedin/Liger-Kernel/tree/main/test/convergence) | [MIT](https://github.com/karpathy/llm.c/blob/master/LICENSE) |
+
+Many thanks to the contributors to these projects for their invaluable work that helped make Liger possible.
diff --git a/docs/contributing.md b/docs/contributing.md
new file mode 100755
index 0000000000000000000000000000000000000000..8388f6c432b2ceb82fefd48ddd304452229fe480
--- /dev/null
+++ b/docs/contributing.md
@@ -0,0 +1,114 @@
+
+
+Thank you for your interest in contributing to Liger-Kernel! This guide will help you set up your development environment, add a new kernel, run tests, and submit a pull request (PR).
+
+### Maintainers
+@ByronHsu(admin) @qingquansong @yundai424 @kvignesh1420 @lancerts @JasonZhu1313 @shimizust @vaibhavjindal @tcc0403 @momochen
+
+## Interested in the ticket?
+
+Leave `#take` in the comment and tag the maintainer.
+
+## Setting Up Your Development Environment
+
+1. **Clone the Repository**
+```sh
+git clone https://github.com/linkedin/Liger-Kernel.git
+cd Liger-Kernel
+```
+2. **Install Dependencies and Editable Package**
+```
+pip install -e .[dev]
+```
+If you encounter the error `no matches found: .[dev]`, use
+```
+pip install -e .'[dev]'
+```
+3. **Install pre-commit hooks using [`prek`](https://prek.j178.dev/), a `pre-commit` alternative written in Rust**
+```
+prek install
+```
+Run the pre-commit checks without committing (`-a` is equivalent to `--all-files`):
+```
+prek run -a
+```
+
+## Structure
+
+### Source Code
+- `ops/`: Core Triton operations.
+- `transformers/`: PyTorch `nn.Module` implementations built on Triton operations, compliant with the `transformers` API.
+
+### Tests
+
+- `transformers/`: Correctness tests for the Triton-based layers.
+- `convergence/`: Patches Hugging Face models with all kernels, runs multiple iterations, and compares weights, logits, and loss layer-by-layer.
+
+### Benchmark
+
+- `benchmark/`: Execution time and memory benchmarks compared to Hugging Face layers.
+
+## Adding support for a new model
+To get familiar with the folder structure, please refer [here](https://github.com/linkedin/Liger-Kernel?tab=readme-ov-file#structure).
+
+1. **Figure out the kernels that can be monkey-patched**
+    - Check the `src/liger_kernel/ops` directory to find the kernels that can be monkey-patched.
+    - Kernels like Fused Linear Cross Entropy require a custom lce_forward function to allow monkey-patching. For adding kernels requiring a similar approach, ensure that you create the corresponding forward function in the `src/liger_kernel/transformers/model` directory.
+
+2. **Monkey-patch the HuggingFace model**
+    - Add the monkey-patching code in the `src/liger_kernel/transformers/monkey_patch.py` file.
+    - Ensure that the monkey-patching function is added to the `__init__.py` file in the `src/liger_kernel/transformers/` directory.
+
+3. **Add Unit Tests**
+    - Create unit tests and convergence tests for the monkey-patched model in the tests directory. Ensure that your tests cover all functionalities of the monkey-patched model.
+
+## Adding a New Kernel
+To get familiar with the folder structure, please refer [here](https://github.com/linkedin/Liger-Kernel?tab=readme-ov-file#structure).
+
+1. **Create Your Kernel**
+Add your kernel implementation in `src/liger_kernel/`.
+
+2. **Add Unit Tests**
+Create unit tests and convergence tests for your kernel in the tests directory. Ensure that your tests cover all kernel functionalities.
+
+3. **Add Benchmark Script**
+Add a benchmarking script under `benchmark/scripts` using the naming convention `benchmark_{kernel_name}.py` showing the performance difference between the Liger kernel and HuggingFace (a minimal skeleton is sketched at the end of this guide).
+
+## Run tests
+
+### Use Makefile to run full tests
+1. Run `make test` to ensure correctness.
+2. Run `make checkstyle` to ensure code style.
+3. Run `make test-convergence` to ensure convergence.
+
+### Run pytest on single file
+`python -m pytest test_sample.py::test_function_name`
+
+## Run kernel benchmarks
+The `/benchmark` directory contains benchmarking scripts for the individual kernels, demonstrating differences in speed and memory usage between using Liger and HuggingFace module implementations.
+
+1. Run `make run-benchmarks` to run all benchmarking scripts and append data to `benchmark/data/all_benchmark_data.csv`.
+    - Existing entries that are the same (based on `kernel_name`, `kernel_provider`, `kernel_operation_mode`, `metric_name`, `x_name`, `x_value`, `extra_benchmark_config_str`, and `gpu_name`) will not be overwritten.
+2. Run `make run-benchmarks OVERWRITE=1` to overwrite any existing entries that have the same configuration.
+3. Run `python benchmark/scripts/benchmark_{kernel_name}.py` to run an individual benchmark.
+4. You can use the `benchmark/benchmarks_visualizer.py` script to generate visualizations from the CSV; these are then saved to the `benchmark/visualizations` directory (note: this directory is not tracked by git).
+
+## Submit PR
+Fork the repo, copy and paste the successful test logs in the PR, and submit the PR following the PR template (**[example PR](https://github.com/linkedin/Liger-Kernel/pull/21)**).
+
+> As a contributor, you represent that the code you submit is your original work or that of your employer (in which case you represent you have the right to bind your employer). By submitting code, you (and, if applicable, your employer) are licensing the submitted code to LinkedIn and the open source community subject to the BSD 2-Clause license.
+
+#### Release (Maintainer only)
+
+1. Bump the version in pyproject.toml to the desired version (for example, `0.2.0`)
+2. Submit a PR and merge
+3. Create a new release based on the current HEAD, with the tag name prefixed by `v`, for example `v0.2.0`. Alternatively, if you want to create a release based on a different commit hash, run `git tag v0.2.0 && git push origin v0.2.0` and create the release based on that tag
+4. Add release notes: the minimum requirement is to click the `Generate Release Notes` button, which automatically generates 1) the changes included and 2) new contributors. It's good to add sections on top to highlight the important changes.
+5. A new pip upload will be triggered upon a new release. NOTE: Both pre-releases and official releases will trigger the workflow to build the wheel and publish to PyPI, so please be sure that steps 1-3 are followed correctly!
+
+### Notes on version
+Here we follow [semantic versioning](https://semver.org/). Denoting the version as `major.minor.patch`, we increment:
+
+- Major version when there is a backward-incompatible change.
+- Minor version when there is new backward-compatible functionality.
+- Patch version for bug fixes.
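+
+## Example: Minimal Benchmark Script Skeleton
+
+The sketch below shows what a new `benchmark/scripts/benchmark_{kernel_name}.py` can look like, built on the helpers defined in `benchmark/scripts/utils.py` (`SingleBenchmarkRunInput`, `SingleBenchmarkRunOutput`, `QUANTILES`, `run_benchmarks`). It uses `LigerSoftmax` versus `torch.softmax` purely as a stand-in comparison; the no-argument `LigerSoftmax()` constructor, the same-directory `from utils import ...` pattern, and the chosen shapes are assumptions to adapt to your kernel, not a definitive recipe.
+
+```python
+import torch
+import triton
+
+from utils import (
+    QUANTILES,
+    SingleBenchmarkRunInput,
+    SingleBenchmarkRunOutput,
+    parse_benchmark_script_args,
+    run_benchmarks,
+)
+
+from liger_kernel.transformers import LigerSoftmax  # assumed no-arg constructor
+from liger_kernel.utils import infer_device
+
+device = infer_device()
+
+
+def bench_speed_softmax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    N = input.x
+    x = torch.randn(N, N, device=device)
+    liger_softmax = LigerSoftmax()
+
+    def fwd():
+        # Dispatch on the provider being benchmarked.
+        if input.kernel_provider == "liger":
+            return liger_softmax(x)
+        return torch.softmax(x, dim=-1)
+
+    # QUANTILES = [0.5, 0.2, 0.8], so do_bench returns (median, p20, p80).
+    ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
+    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)
+
+
+if __name__ == "__main__":
+    args = parse_benchmark_script_args()
+    run_benchmarks(
+        bench_test_fn=bench_speed_softmax,
+        kernel_name="softmax",
+        kernel_operation_modes=["forward"],
+        metric_name="speed",
+        metric_unit="ms",
+        x_name="N",
+        x_label="input size",
+        x_values=[2**i for i in range(8, 13)],
+        kernel_providers=["liger", "torch"],
+        # run_benchmarks iterates over extra_benchmark_configs, so pass at
+        # least one (possibly empty) config dict.
+        extra_benchmark_configs=[{}],
+        overwrite=args.overwrite,
+    )
+```
+
+A matching `bench_memory_*` function would wrap the same `fwd` with `_test_memory` and be registered via a second `run_benchmarks` call with `metric_name="memory"` and `metric_unit="MB"`.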
diff --git a/docs/images/banner.GIF b/docs/images/banner.GIF new file mode 100755 index 0000000000000000000000000000000000000000..a6a3f63030044efd0dd184a15f4198fab7b085d4 Binary files /dev/null and b/docs/images/banner.GIF differ diff --git a/docs/images/compose.gif b/docs/images/compose.gif new file mode 100755 index 0000000000000000000000000000000000000000..1a7994e536d07d8f2c292e27ca1e0ebfd6a165ff Binary files /dev/null and b/docs/images/compose.gif differ diff --git a/docs/images/e2e-memory.png b/docs/images/e2e-memory.png new file mode 100755 index 0000000000000000000000000000000000000000..ab2f9176055e353199e0bc0ac73e891c8acfe804 Binary files /dev/null and b/docs/images/e2e-memory.png differ diff --git a/docs/images/e2e-tps.png b/docs/images/e2e-tps.png new file mode 100755 index 0000000000000000000000000000000000000000..624ba96d956a92b4612e07edb00a3891328d7c78 Binary files /dev/null and b/docs/images/e2e-tps.png differ diff --git a/docs/images/logo-banner.png b/docs/images/logo-banner.png new file mode 100755 index 0000000000000000000000000000000000000000..fe69d0044269597f78d733cc594fe96c1c23d1d0 Binary files /dev/null and b/docs/images/logo-banner.png differ diff --git a/docs/images/patch.gif b/docs/images/patch.gif new file mode 100755 index 0000000000000000000000000000000000000000..851d239435fffbbf9ad886cc60567b29b854cd1d Binary files /dev/null and b/docs/images/patch.gif differ diff --git a/docs/images/post-training.png b/docs/images/post-training.png new file mode 100755 index 0000000000000000000000000000000000000000..44e33c7be995e28710e2f62d521313cacace362b Binary files /dev/null and b/docs/images/post-training.png differ diff --git a/docs/index.md b/docs/index.md new file mode 100755 index 0000000000000000000000000000000000000000..4342cdcd616cdfe55f0d982fb0a7e220b1bb3038 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,186 @@ + + +# Liger Kernel: Efficient Triton Kernels for LLM Training + + + + + + + + + + + + + + + + + +
+<!-- Badge table (Stable / Nightly / Discord / Build): PyPI download and version badges for the stable and nightly packages, a Discord invite badge, and CI build status badges. -->
+
+
+**Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
+
+We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
+
+## Supercharge Your Model with Liger Kernel
+
+With one line of code, Liger Kernel can increase throughput by more than 20% and reduce memory usage by 60%, thereby enabling longer context lengths, larger batch sizes, and massive vocabularies.
+
+
+| Speed Up | Memory Reduction |
+|--------------------------|-------------------------|
+| ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) |
+
+> **Note:**
+> - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
+> - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.
+
+## Optimize Post Training with Liger Kernel
+
+![Post Training](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/post-training.png)
+
+We provide optimized post training kernels like DPO, ORPO, SimPO, and more, which can reduce memory usage by up to 80%. You can easily use them as Python modules.
+
+```python
+from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
+orpo_loss = LigerFusedLinearORPOLoss()
+y = orpo_loss(lm_head.weight, x, target)
+```
+
+#### Key Features
+
+- **Ease of use:** Simply patch your Hugging Face model with one line of code, or compose your own model using our Liger Kernel modules.
+- **Time and memory efficient:** In the same spirit as Flash-Attn, but for layers like **RMSNorm**, **RoPE**, **SwiGLU**, and **CrossEntropy**! Increases multi-GPU training throughput by 20% and reduces memory usage by 60% with **kernel fusion**, **in-place replacement**, and **chunking** techniques.
+- **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
+- **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
+- **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
+- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift)
+
+### Installation
+
+To install the stable version:
+
+```bash
+$ pip install liger-kernel
+```
+
+To install the nightly version:
+
+```bash
+$ pip install liger-kernel-nightly
+```
+
+To install from source:
+
+```bash
+git clone https://github.com/linkedin/Liger-Kernel.git
+cd Liger-Kernel
+
+# Install Default Dependencies
+# Setup.py will detect whether you are using AMD or NVIDIA
+pip install -e .
+
+# Setup Development Dependencies
+pip install -e ".[dev]"
+```
+
+!!! Note "Dependencies"
+
+    #### CUDA
+
+    - `torch >= 2.1.2`
+    - `triton >= 2.3.0`
+
+    #### ROCm
+
+    - `torch >= 2.5.0` Install according to the instructions on the PyTorch official webpage.
+    - `triton >= 3.0.0` Install from PyPI (e.g. `pip install triton==3.0.0`).
+
+!!! Tip "Optional Dependencies"
+
+    - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working with will dictate the minimum version of transformers.
+
+!!! Note
+    Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton).
+
+
+#### Sponsorship and Collaboration
+
+- [AMD](https://www.amd.com/en.html): Providing AMD GPUs for our AMD CI.
+- [Intel](https://www.intel.com/): Providing Intel GPUs for our Intel CI.
+- [Modal](https://modal.com/): Free 3000 credits from GPU MODE IRL for our NVIDIA CI.
+- [EmbeddedLLM](https://embeddedllm.com/): Making Liger Kernel run fast and stable on AMD.
+- [HuggingFace](https://huggingface.co/): Integrating Liger Kernel into Hugging Face Transformers and TRL.
+- [Lightning AI](https://lightning.ai/): Integrating Liger Kernel into Lightning Thunder.
+- [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
+- [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.
+
+
+!!! Note "Contact"
+
+    - For issues, create a GitHub ticket in this repository.
+    - For open discussion, join [our Discord channel](https://discord.gg/gpumode).
+    - For formal collaboration, send an email to byhsu@linkedin.com.
+
+### Cite this work
+
+BibTeX entry:
+```bib
+@inproceedings{
+hsu2025ligerkernel,
+title={Liger-Kernel: Efficient Triton Kernels for {LLM} Training},
+author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen and Zhipeng Wang},
+booktitle={Championing Open-source DEvelopment in ML Workshop @ ICML25},
+year={2025},
+url={https://openreview.net/forum?id=36SjAIT42G}
+}
+```
+
+### Star History
+[![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)
diff --git a/docs/license.md b/docs/license.md new file mode 100755 index 0000000000000000000000000000000000000000..53e5e7d25e9487689904dcf12a22297d9b4e85a9 --- /dev/null +++ b/docs/license.md @@ -0,0 +1,8 @@ +This project is licensed under the [BSD 2-CLAUSE](https://github.com/linkedin/Liger-Kernel/blob/main/LICENSE) License (see `LICENSE` for details). +It also includes components from projects licensed under: + +- Apache License 2.0 (see `LICENSE-APACHE-2.0` for details). +- MIT License (see `LICENSE-MIT-AutoAWQ` for details). +- MIT License (see `LICENSE-MIT-Efficient Cross Entropy` for details). +- MIT License (see `LICENSE-MIT-llmc` for details). +- MIT License (see `LICENSE-MIT-triton` for details). \ No newline at end of file diff --git a/examples/alignment/accelerate_config.yaml b/examples/alignment/accelerate_config.yaml new file mode 100755 index 0000000000000000000000000000000000000000..e70f3cdcf7744b6a542583ad87fc67bbe95d835e --- /dev/null +++ b/examples/alignment/accelerate_config.yaml @@ -0,0 +1,26 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: false + fsdp_offload_params: false + fsdp_sharding_strategy: FULL_SHARD + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_sync_module_states: true + fsdp_use_orig_params: true +machine_rank: 0 +main_training_function: main +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/alignment/run_orpo.py b/examples/alignment/run_orpo.py new file mode 100755 index 0000000000000000000000000000000000000000..7dc9450c0160a70dc595c910529dfcfc265c3943 --- /dev/null +++ b/examples/alignment/run_orpo.py @@ -0,0 +1,35 @@ +import torch + +from datasets import load_dataset +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer +from trl import ORPOConfig # noqa: F401 + +from liger_kernel.transformers.trainer import LigerORPOTrainer # noqa: F401 + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.2-1B-Instruct", + dtype=torch.bfloat16, +) + +tokenizer = AutoTokenizer.from_pretrained( + "meta-llama/Llama-3.2-1B-Instruct", + max_length=512, + padding="max_length", +) +tokenizer.pad_token = tokenizer.eos_token + +train_dataset = load_dataset("trl-lib/tldr-preference", split="train") + +training_args = ORPOConfig( + output_dir="Llama3.2_1B_Instruct", + beta=0.1, + max_length=128, + per_device_train_batch_size=32, + max_steps=100, + save_strategy="no", +) + +trainer = LigerORPOTrainer(model=model, args=training_args, tokenizer=tokenizer, train_dataset=train_dataset) + +trainer.train() diff --git a/examples/huggingface/README.md b/examples/huggingface/README.md new file mode 100755 index 0000000000000000000000000000000000000000..41de0dcb737ae912d3935e10d23f23f9ebe212e5 --- /dev/null +++ b/examples/huggingface/README.md @@ -0,0 +1,55 @@ +# Liger-Kernel Example with HuggingFace Trainer + +## How to Run + +### Locally on a GPU machine +You can run the example locally on a GPU machine. The default hyperparameters and configurations work on single node with 4xA100 80GB GPUs. 
+
+```bash
+pip install -r requirements.txt
+sh run_{MODEL}.sh
+```
+
+### Remotely on Modal
+If you do not have access to a GPU machine, you can run the example on Modal, a serverless platform that runs your code on remote GPU machines. You can sign up for a free account at [Modal](https://www.modal.com/).
+
+```bash
+pip install modal
+modal setup # authenticate with Modal
+modal run launch_on_modal.py --script "run_qwen2_vl.sh"
+```
+
+**Notes**
+1. This example uses an optional `use_liger` flag. If true, it applies a one-line monkey patch (sketched above) to enable the Liger kernel.
+2. The example uses the Llama3 model, which requires a community license agreement and HuggingFace Hub login. If you want to use Llama3 in this example, please make sure you have done the following:
+    * Agree on the community license agreement https://huggingface.co/meta-llama/Meta-Llama-3-8B
+    * Run `huggingface-cli login` and enter your HuggingFace token
+3. The default hyperparameters and configurations work on a single node with 4xA100 80GB GPUs. For running on devices with less GPU RAM, consider reducing the per-GPU batch size and/or enabling `CPUOffload` in FSDP.
+
+
+## Benchmark Result
+
+### LLaMA
+Benchmark conditions: LLaMA 3-8B, Alpaca Dataset, Max seq len = 512, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 4 A100s.
+
+Throughput improves by around 20%, while GPU memory usage drops by 40%. This allows you to train the model on smaller GPUs, use larger batch sizes, or handle longer sequence lengths without incurring additional costs.
+
+![Throughput](img/llama_tps.png)
+![GPU Memory Allocated](img/llama_mem_alloc.png)
+
+### QWEN
+Benchmark conditions: Qwen2-7B, Alpaca Dataset, Max seq len = 512, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 4 A100s.
+
+Throughput improves by around 10%, while GPU memory usage drops by 50%.
+
+![Throughput](img/qwen_tps.png)
+![GPU Memory Allocated](img/qwen_mem_alloc.png)
+
+
+### GEMMA 7B
+Benchmark conditions: Gemma-7B, Alpaca Dataset, Max seq len = 512, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 4 A100s.
+
+Throughput improves by around 24%, while GPU memory usage drops by 33%.
+
+![Throughput](img/gemma_7b_tp.png)
+![GPU Memory Allocated](img/gemma_7b_mem.png)
diff --git a/examples/huggingface/callback.py b/examples/huggingface/callback.py
new file mode 100755
index 0000000000000000000000000000000000000000..c834fc56634a4444fb27edd2b1f9a27f061c0001
--- /dev/null
+++ b/examples/huggingface/callback.py
@@ -0,0 +1,257 @@
+import time
+
+from dataclasses import dataclass
+
+import torch
+import transformers
+
+from transformers import TrainerControl
+from transformers import TrainerState
+from transformers import TrainingArguments
+
+from liger_kernel.utils import infer_device
+
+# https://simple.wikipedia.org/wiki/Byte
+# For memory, we use the binary system
+M_BIN_UNIT = 2**20
+# For metrics (tflops), we use the decimal system
+T_DEC_UNIT = 10**12
+
+
+def round_to_n_decimal(x, n):
+    return round(x, n)
+
+
+@dataclass
+class Precision:
+    """
+    Precision is a dataclass to store the number of decimal points for each metric.
+    """
+
+    n_decimal_time: int
+    n_decimal_memory: int
+    n_decimal_TPS: int
+
+
+@dataclass
+class State:
+    """
+    State is a dataclass to store the internal state of the efficiency callback.
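+
+    It tracks the warmup boundary, peak memory, and the elapsed time, steps,
+    and tokens from which the aggregated metrics are derived.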
+ """ + + n_warmup_steps: int = 0 + total_peak_memory_allocated: float = float("-inf") + total_peak_memory_reserved: float = float("-inf") + + step_start_time: float = 0.0 + elapsed_time: float = 0.0 + + elapsed_step: int = 0 + + step_start_tokens_seen: int = 0 + elapsed_tokens_seen: int = 0 + + global_start_step: int = 0 + + +@dataclass +class Time: + """ + Time is a dataclass to store the time-related metrics. + """ + + step: int = 0 + step_time_sec: float = 0.0 + avg_step_time_sec: float = 0.0 + time_to_completion_sec: float = 0.0 + estimated_total_time_sec: float = 0.0 + + +@dataclass +class Memory: + """ + Memory is a dataclass to store the memory-related metrics. + """ + + step_peak_memory_allocated_MB: float = 0.0 + step_peak_memory_reserved_MB: float = 0.0 + total_peak_memory_allocated_MB: float = 0.0 + total_peak_memory_reserved_MB: float = 0.0 + + +@dataclass +class TPS: + """ + TPS is a dataclass to store the tokens per second metrics. + """ + + step_tokens_per_second: float = 0.0 + avg_tokens_per_second: float = 0.0 + + +class EfficiencyCallback(transformers.TrainerCallback): + """ + EfficiencyCallback is a callback to track the efficiency of the training process. + The tracked stats include: step time, memory, and throughput. + + It requires including `--include_num_input_tokens_seen` and `logging_steps=1` in the training arguments. + + Args: + n_warmup_steps: number of warmup steps + The stats in the first n_warmup_steps will not be added into the aggregated stats + This is because the first few steps might take longer due to jit compliation and other initialization overheads + n_decimal_time: number of decimal points for time + n_decimal_memory: number of decimal points for memory + n_decimal_TPS: number of decimal points for TPS + """ + + def __init__(self, n_warmup_steps=2, n_decimal_time=2, n_decimal_memory=2, n_decimal_TPS=2): + self.state = State( + n_warmup_steps, + ) + + self.precision = Precision(n_decimal_time, n_decimal_memory, n_decimal_TPS) + + self.time = Time() + self.memory = Memory() + self.tps = TPS() + self.device = infer_device() + + def on_init_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """ + Event called at the end of the initialization of the [`Trainer`]. + """ + if not args.include_num_input_tokens_seen: + raise Exception( + 'Please pass training argument "--include_num_input_tokens_seen" to track tokens per second' + ) + if args.logging_steps != 1: + raise Exception("Please set logging_steps=1 to track the efficiency metrics accurately") + + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # if loaded from checkpoints, global_start_step is not 1 but state.global_step + self.state.global_start_step = state.global_step + + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + logs: dict[str, float], + **kwargs, + ): + if state.global_step < (self.state.global_start_step + self.state.n_warmup_steps): + return + else: + # spread self.time, self.memory, self.tps to logs + logs.update(self.time.__dict__) + logs.update(self.memory.__dict__) + logs.update(self.tps.__dict__) + + def on_step_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """ + Event called at the beginning of a training step. If using gradient accumulation, one training step might take + several inputs. 
+ """ + # memory + getattr(torch, self.device).reset_peak_memory_stats() + + # time + self.state.step_start_time = time.perf_counter() + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if state.global_step < (self.state.global_start_step + self.state.n_warmup_steps): + # The end the current step_start_tokens_seen is the start of next iteration + + # tokens + self.state.step_start_tokens_seen = state.num_input_tokens_seen + return + + # time + current_time = time.perf_counter() + step_time = current_time - self.state.step_start_time + self.state.elapsed_time += step_time + + # step + global_step = state.global_step + self.state.elapsed_step += 1 + avg_step_time = self.state.elapsed_time / self.state.elapsed_step + + self.time.step = global_step + self.time.step_time_sec = round_to_n_decimal(step_time, self.precision.n_decimal_time) + self.time.avg_step_time_sec = round_to_n_decimal(avg_step_time, self.precision.n_decimal_time) + self.time.time_to_completion_sec = round_to_n_decimal( + avg_step_time * (state.max_steps - global_step), + self.precision.n_decimal_time, + ) + self.time.estimated_total_time_sec = round_to_n_decimal( + avg_step_time * state.max_steps, self.precision.n_decimal_time + ) + + # memory + step_peak_memory_allocated = getattr(torch, self.device).memory.max_memory_allocated() + step_peak_memory_reserved = getattr(torch, self.device).memory.max_memory_reserved() + + self.memory.step_peak_memory_allocated_MB = round_to_n_decimal( + step_peak_memory_allocated / M_BIN_UNIT, self.precision.n_decimal_memory + ) + self.state.total_peak_memory_allocated = max(self.state.total_peak_memory_allocated, step_peak_memory_allocated) + self.memory.total_peak_memory_allocated_MB = round_to_n_decimal( + self.state.total_peak_memory_allocated / M_BIN_UNIT, + self.precision.n_decimal_memory, + ) + + self.memory.step_peak_memory_reserved_MB = round_to_n_decimal( + step_peak_memory_reserved / M_BIN_UNIT, self.precision.n_decimal_memory + ) + + self.state.total_peak_memory_reserved = max(self.state.total_peak_memory_reserved, step_peak_memory_reserved) + + self.memory.total_peak_memory_reserved_MB = round_to_n_decimal( + self.state.total_peak_memory_reserved / M_BIN_UNIT, + self.precision.n_decimal_memory, + ) + + # tokens + step_tokens_seen = state.num_input_tokens_seen - self.state.step_start_tokens_seen + + self.state.elapsed_tokens_seen += step_tokens_seen + + self.tps.step_tokens_per_second = round_to_n_decimal( + step_tokens_seen / step_time, + self.precision.n_decimal_TPS, + ) + + self.tps.avg_tokens_per_second = round_to_n_decimal( + self.state.elapsed_tokens_seen / self.state.elapsed_time, + self.precision.n_decimal_TPS, + ) + + # The end the current step_start_tokens_seen is the start of next iteration + + # tokens + self.state.step_start_tokens_seen = state.num_input_tokens_seen diff --git a/examples/huggingface/config/fsdp_config.json b/examples/huggingface/config/fsdp_config.json new file mode 100755 index 0000000000000000000000000000000000000000..45894b0b50bf5f843837ebe2f07d199f5e8a8df0 --- /dev/null +++ b/examples/huggingface/config/fsdp_config.json @@ -0,0 +1,5 @@ +{ + "backward_prefetch": "backward_pre", + "forward_prefetch": "true", + "activation_checkpointing": true +} \ No newline at end of file diff --git a/examples/huggingface/img/gemma_7b_mem.png b/examples/huggingface/img/gemma_7b_mem.png new file mode 100755 index 0000000000000000000000000000000000000000..940d0918bd2c74aaff45666b5d5b4f4778f3fb56 
Binary files /dev/null and b/examples/huggingface/img/gemma_7b_mem.png differ diff --git a/examples/huggingface/img/gemma_7b_tp.png b/examples/huggingface/img/gemma_7b_tp.png new file mode 100755 index 0000000000000000000000000000000000000000..7163543df49e0fe615e6f881f7cb4209125146aa Binary files /dev/null and b/examples/huggingface/img/gemma_7b_tp.png differ diff --git a/examples/huggingface/img/llama_mem_alloc.png b/examples/huggingface/img/llama_mem_alloc.png new file mode 100755 index 0000000000000000000000000000000000000000..8f89581e5c0c7f2838aa1a9a8bedb05789fe3e18 Binary files /dev/null and b/examples/huggingface/img/llama_mem_alloc.png differ diff --git a/examples/huggingface/img/llama_tps.png b/examples/huggingface/img/llama_tps.png new file mode 100755 index 0000000000000000000000000000000000000000..37dd35a3ee417ee213089c308c271748695a808a Binary files /dev/null and b/examples/huggingface/img/llama_tps.png differ diff --git a/examples/huggingface/img/qwen_mem_alloc.png b/examples/huggingface/img/qwen_mem_alloc.png new file mode 100755 index 0000000000000000000000000000000000000000..9f4154bbb4bfe815938ebad940c8830c95b69891 Binary files /dev/null and b/examples/huggingface/img/qwen_mem_alloc.png differ diff --git a/examples/huggingface/img/qwen_tps.png b/examples/huggingface/img/qwen_tps.png new file mode 100755 index 0000000000000000000000000000000000000000..cbc86c8f41fddc118861541158645b632f2825ae Binary files /dev/null and b/examples/huggingface/img/qwen_tps.png differ diff --git a/examples/huggingface/launch_on_modal.py b/examples/huggingface/launch_on_modal.py new file mode 100755 index 0000000000000000000000000000000000000000..1171ea42d94f55fbdbd954ba44b3a4447639d800 --- /dev/null +++ b/examples/huggingface/launch_on_modal.py @@ -0,0 +1,69 @@ +""" +launch_on_modal.py + +This tool is designed to launch scripts using Modal. + +It sets up the necessary environment, including GPU resources and python dependencies, +and executes the specified training script remotely. + +### Setup and Usage +```bash +pip install modal +modal setup # authenticate with Modal +export HF_TOKEN="your_huggingface_token" # if using a gated model such as llama3 +modal run launch_on_modal.py --script "run_qwen2_vl.sh" +``` + +### Caveats +This tool is intended as an easy on-ramp to using Liger-Kernel for fine-tuning LLMs and +VLMs - it is a reproducible way to run benchmarks and example scripts. However, it is not +the best way to develop a model on Modal, as it re-downloads the model and dataset each +time it is run. For iterative development, consider using `modal.Volume` to cache the +model and dataset between runs. 
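+
+A rough sketch of that caching approach (the volume name and mount path below are
+illustrative, not part of this script):
+
+```python
+cache = modal.Volume.from_name("liger-hf-cache", create_if_missing=True)
+
+@app.function(volumes={"/root/.cache/huggingface": cache}, image=image)
+def cached_launch(script: str):
+    ...  # same launch logic; model/dataset downloads persist across runs
+```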
+""" + +import os + +import modal + +from modal import gpu + +TWO_HOURS = 2 * 60 * 60 +SIXTEEN_GB = 16 * 1024 + +app = modal.App("liger-example") + +image = modal.Image.debian_slim().pip_install_from_requirements("requirements.txt").copy_local_dir(".", "/root") + +if "HF_TOKEN" not in os.environ: + print("HF_TOKEN not found in environment variables, using an empty token.") +hf_token_secret = modal.Secret.from_dict({"HF_TOKEN": os.environ.get("HF_TOKEN", "")}) + + +@app.function( + gpu=gpu.A100(count=4, size="80GB"), + image=image, + timeout=TWO_HOURS, + memory=SIXTEEN_GB, + secrets=[hf_token_secret], +) +def launch_script(script: str): + import subprocess + + script_path = f"/root/{script}" + os.chmod(script_path, 0o755) # make script executable + + print(f"Running script: {script_path}") + subprocess.run([script_path], check=True, cwd="/root", env=os.environ.copy()) + + +@app.local_entrypoint() +def main(script: str): + """ + Launch a script remotely on modal. + ```bash + export HF_TOKEN="your_huggingface_token" # if using a gated model such as llama3 + modal run --detach launch_on_modal.py --script "run_qwen2_vl.sh" + ``` + """ + launch_script.remote(script=script) diff --git a/examples/huggingface/requirements.txt b/examples/huggingface/requirements.txt new file mode 100755 index 0000000000000000000000000000000000000000..d6d10e9ecd64b0f4e3ed93b43aea09e8b154a6fa --- /dev/null +++ b/examples/huggingface/requirements.txt @@ -0,0 +1,6 @@ +transformers==4.45.2 +trl +liger-kernel +triton +torch +torchvision \ No newline at end of file diff --git a/examples/huggingface/run_benchmarks.sh b/examples/huggingface/run_benchmarks.sh new file mode 100755 index 0000000000000000000000000000000000000000..cf4234aeaa4d6ad56c742f23657dba14140b0138 --- /dev/null +++ b/examples/huggingface/run_benchmarks.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +## Benchmarking Script +## Runs the training script with different configurations and logs the results + +MODEL_TYPE="mistral" +MODEL_PATH="mistralai/Mistral-7B-v0.1" +USE_LIGER_VALUES=("True" "False") +BATCH_SIZE_VALUES=(64 128 192) +NUM_REP=5 +MAX_STEPS=20 +DATASET_PATH="tatsu-lab/alpaca" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +mkdir -p "${SCRIPT_DIR}/results" + +for USE_LIGER in "${USE_LIGER_VALUES[@]}"; do + for BATCH_SIZE in "${BATCH_SIZE_VALUES[@]}"; do + echo "Running with use_liger=$USE_LIGER and batch_size=$BATCH_SIZE" + + for ((i=1; i<=NUM_REP; i++)); do + + LOG_FILE="${SCRIPT_DIR}/results/${MODEL_TYPE}_use_liger_${USE_LIGER}_batch_size_${BATCH_SIZE}_rep_${i}.log" + + torchrun --nnodes=1 --nproc-per-node=4 training.py \ + --bf16 \ + --num_train_epochs 1 \ + --max_steps $MAX_STEPS \ + --model_name $MODEL_PATH \ + --dataset $DATASET_PATH \ + --per_device_train_batch_size $BATCH_SIZE \ + --per_device_eval_batch_size 16 \ + --eval_strategy "no" \ + --save_strategy "no" \ + --learning_rate 6e-6 \ + --weight_decay 0.05 \ + --warmup_ratio 0.1 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --include_num_input_tokens_seen \ + --report_to none \ + --fsdp "full_shard auto_wrap" \ + --fsdp_config config/fsdp_config.json \ + --seed 42 \ + --use_liger $USE_LIGER \ + --output_dir model_output_dir \ + > $LOG_FILE + + sleep 5 + done + done +done \ No newline at end of file diff --git a/examples/huggingface/run_gemma.sh b/examples/huggingface/run_gemma.sh new file mode 100755 index 0000000000000000000000000000000000000000..c882f5e7f720f32f34c3d8fdad92534c6da16411 --- /dev/null +++ b/examples/huggingface/run_gemma.sh @@ -0,0 +1,22 @@ +#!/bin/bash 
+ +torchrun --nnodes=1 --nproc-per-node=4 training.py \ + --model_name "google/gemma-7b-it" \ + --bf16 \ + --max_steps 20 \ + --per_device_train_batch_size 24 \ + --per_device_eval_batch_size 1 \ + --eval_strategy "no" \ + --save_strategy "no" \ + --learning_rate 6e-6 \ + --weight_decay 0.05 \ + --warmup_ratio 0.1 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --include_num_input_tokens_seen \ + --report_to none \ + --fsdp "full_shard auto_wrap" \ + --fsdp_config config/fsdp_config.json \ + --seed 42 \ + --use_liger True \ + --output_dir alpaca_finetuning diff --git a/examples/huggingface/run_llama.sh b/examples/huggingface/run_llama.sh new file mode 100755 index 0000000000000000000000000000000000000000..b6a1fc73f74572ee56f4850e74021ea958e8843b --- /dev/null +++ b/examples/huggingface/run_llama.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +torchrun --nnodes=1 --nproc-per-node=4 training.py \ + --bf16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 64 \ + --per_device_eval_batch_size 64 \ + --eval_strategy "no" \ + --save_strategy "no" \ + --learning_rate 6e-6 \ + --weight_decay 0.05 \ + --warmup_ratio 0.1 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --include_num_input_tokens_seen \ + --report_to none \ + --fsdp "full_shard auto_wrap" \ + --fsdp_config config/fsdp_config.json \ + --seed 42 \ + --use_liger True \ + --output_dir alpaca_finetuning diff --git a/examples/huggingface/run_qwen.sh b/examples/huggingface/run_qwen.sh new file mode 100755 index 0000000000000000000000000000000000000000..54a157fbc265a36e75965f65e3686937d9eb1485 --- /dev/null +++ b/examples/huggingface/run_qwen.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +torchrun --nnodes=1 --nproc-per-node=4 training.py \ + --model_name "Qwen/Qwen2-7B" \ + --bf16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 48 \ + --per_device_eval_batch_size 64 \ + --eval_strategy "no" \ + --save_strategy "no" \ + --learning_rate 6e-6 \ + --weight_decay 0.05 \ + --warmup_ratio 0.1 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --include_num_input_tokens_seen \ + --report_to none \ + --fsdp "full_shard auto_wrap" \ + --fsdp_config config/fsdp_config.json \ + --seed 42 \ + --use_liger True \ + --output_dir alpaca_finetuning diff --git a/examples/huggingface/run_qwen2_vl.sh b/examples/huggingface/run_qwen2_vl.sh new file mode 100755 index 0000000000000000000000000000000000000000..963600f0120f0a9da95da417ca87c7cbf4e4f3ff --- /dev/null +++ b/examples/huggingface/run_qwen2_vl.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +torchrun --nnodes=1 --nproc-per-node=4 training_multimodal.py \ + --model_name "Qwen/Qwen2-VL-7B-Instruct" \ + --bf16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 8 \ + --eval_strategy "no" \ + --save_strategy "no" \ + --learning_rate 6e-6 \ + --weight_decay 0.05 \ + --warmup_ratio 0.1 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --include_num_input_tokens_seen \ + --report_to none \ + --fsdp "full_shard auto_wrap" \ + --fsdp_config config/fsdp_config.json \ + --seed 42 \ + --use_liger True \ + --output_dir multimodal_finetuning diff --git a/examples/huggingface/training.py b/examples/huggingface/training.py new file mode 100755 index 0000000000000000000000000000000000000000..d431b10111215a87c636c80da92763f329d2e6a9 --- /dev/null +++ b/examples/huggingface/training.py @@ -0,0 +1,79 @@ +from dataclasses import dataclass + +import datasets +import torch +import transformers + +from callback import EfficiencyCallback +from trl import 
DataCollatorForCompletionOnlyLM +from trl import SFTTrainer + +from liger_kernel.transformers import AutoLigerKernelForCausalLM + + +@dataclass +class CustomArguments: + model_name: str = "meta-llama/Meta-Llama-3-8B" + dataset: str = "tatsu-lab/alpaca" + max_seq_length: int = 512 + use_liger: bool = False + + +def formatting_prompts_func(example): + return example["text"] + + +def train(): + parser = transformers.HfArgumentParser((transformers.TrainingArguments, CustomArguments)) + training_args, custom_args = parser.parse_args_into_dataclasses() + tokenizer = transformers.AutoTokenizer.from_pretrained( + custom_args.model_name, + padding_side="left", + truncation_side="left", + ) + tokenizer.pad_token = tokenizer.eos_token + + dataset = datasets.load_dataset(custom_args.dataset)["train"].train_test_split(test_size=0.1) + train_dataset = dataset["train"] + eval_dataset = dataset["test"] + response_prompt = tokenizer.encode("### Response:\n", add_special_tokens=False) + collator = DataCollatorForCompletionOnlyLM( + tokenizer=tokenizer, + response_template=response_prompt, + pad_to_multiple_of=16, + ) + + if custom_args.use_liger: + model = AutoLigerKernelForCausalLM.from_pretrained( + custom_args.model_name, + trust_remote_code=True, + use_cache=False, + dtype=torch.bfloat16, + # These args will get passed to the appropriate apply_liger_kernel_to_* function + # to override the default settings + # cross_entropy=True, + # fused_linear_cross_entropy=False, + ) + else: + model = transformers.AutoModelForCausalLM.from_pretrained( + custom_args.model_name, + trust_remote_code=True, + use_cache=False, + dtype=torch.bfloat16, + ) + + trainer = SFTTrainer( + model=model, + args=training_args, + data_collator=collator, + max_seq_length=custom_args.max_seq_length, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + formatting_func=formatting_prompts_func, + callbacks=[EfficiencyCallback()], + ) + trainer.train() + + +if __name__ == "__main__": + train() diff --git a/examples/huggingface/training_multimodal.py b/examples/huggingface/training_multimodal.py new file mode 100755 index 0000000000000000000000000000000000000000..1fcee87da2d79abccdb8b7608be6a8f57438c372 --- /dev/null +++ b/examples/huggingface/training_multimodal.py @@ -0,0 +1,169 @@ +import os + +from dataclasses import dataclass + +import datasets +import torch +import transformers + +from callback import EfficiencyCallback +from datasets import Image as ImageFeature +from trl import SFTTrainer + +from liger_kernel.transformers import monkey_patch + + +@dataclass +class CustomArguments: + model_name: str = "Qwen/Qwen2-VL-2B-Instruct" + dataset: str = "HuggingFaceM4/the_cauldron" + dataset_subset: str = "ai2d" + dataset_split: str = "train" + max_seq_length: int = 512 + dataset_text_field: str = "texts" + use_liger: bool = False + + +def construct_model_and_processor(model_name: str, use_liger: bool) -> torch.nn.Module: + if "Qwen2-VL" in model_name: + from transformers import Qwen2VLForConditionalGeneration + + # These settings are used to reduce the memory footprint of the Qwen2-VL model, + # which supports training/inferences on images in their native resolution. Large + # images -> many visual tokens (a max of 16384) -> large memory consumption. + # If fine-tuning for a real-world application, consider these values carefully. 
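+        # For intuition: 256 visual tokens x (28 x 28) pixels/token = 200,704 pixels,
+        # i.e. each image is effectively capped at about 448 x 448 after resizing.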
+ min_visual_tokens_per_image = 256 + max_visual_tokens_per_image = 256 + + processor = transformers.AutoProcessor.from_pretrained( + model_name, + padding_side="left", + truncation_side="left", + min_pixels=min_visual_tokens_per_image * 28 * 28, # patch size is 14x14 + max_pixels=max_visual_tokens_per_image * 28 * 28, # 4 patches / token + ) + processor.tokenizer.pad_token = processor.tokenizer.eos_token + image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") + + if use_liger: + print("Applying Liger Kernel to Qwen2-VL model") + monkey_patch.apply_liger_kernel_to_qwen2_vl( + # These args can be used to override the default Liger settings + # cross_entropy=True, + # fused_linear_cross_entropy=False, + ) + + model = Qwen2VLForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path=model_name, + use_cache=False, + dtype=torch.bfloat16, + low_cpu_mem_usage=True, + attn_implementation="sdpa", + ) + return model, processor, image_token_id + + raise NotImplementedError(f"Model {model_name} not supported") + + +def _validate_and_extract_the_cauldron(examples) -> dict[str, list]: + batch_texts = [] + batch_images = [] + for images, texts in zip(examples["images"], examples["texts"]): + if not images: + raise ValueError("No image found in example from the_cauldron dataset") + if len(images) > 1: + raise ValueError("Only one image per example is supported") + batch_texts.extend(texts) + batch_images.extend([images[0]] * len(texts)) + return {"texts": batch_texts, "images": batch_images} + + +def _format_for_convo(example, tokenizer): + # cauldron data is already in message format {"user": ..., "assistant": ...} + text = example["texts"] + messages = [ + { + "role": "user", + "content": [{"type": "image"}, {"type": "text", "text": text["user"]}], + }, + {"role": "assistant", "content": [{"type": "text", "text": text["assistant"]}]}, + ] + text = tokenizer.apply_chat_template(messages, tokenize=False) + return {"texts": text} + + +def train(): + parser = transformers.HfArgumentParser((transformers.TrainingArguments, CustomArguments)) + training_args, custom_args = parser.parse_args_into_dataclasses() + training_args.remove_unused_columns = False # required to not drop the image column + training_args.dataset_kwargs = {"skip_prepare_dataset": True} + + model, processor, image_token_id = construct_model_and_processor(custom_args.model_name, custom_args.use_liger) + + dataset = ( + datasets.load_dataset( + custom_args.dataset, + custom_args.dataset_subset, + split=custom_args.dataset_split, + ) + .map( + _validate_and_extract_the_cauldron, + batched=True, + num_proc=min(os.cpu_count(), 16), + desc="Extracting text and images", + ) + .map( + _format_for_convo, + fn_kwargs={"tokenizer": processor.tokenizer}, + desc="Formatting for convo", + ) + .cast_column("images", ImageFeature()) + .train_test_split(test_size=0.1) + ) + + train_dataset = dataset["train"] + eval_dataset = dataset["test"] + + def collate_fn(examples): + """ + Taken directly from the TRL documentation with minor modifications: + https://huggingface.co/docs/trl/en/sft_trainer#a-custom-collator-for-processing-multi-modal-data + + Modifications: + 1. `apply_chat_template` is used to preprocess the texts before training begins (see above) + 2. `example["messages"]` -> `example["texts"]` to conform with the_cauldron dataset schema + 3. 
Ignoring image tokens in the loss computation
+        """
+        # Get the texts and images
+        texts = [example["texts"] for example in examples]
+        images = [example["images"] for example in examples]
+
+        # Tokenize the texts and process the images
+        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+
+        # The labels are the input_ids, and we mask the padding tokens in the loss computation
+        labels = batch["input_ids"].clone()
+        labels[labels == processor.tokenizer.pad_token_id] = -100
+
+        # Ignore the image token index in the loss computation
+        labels[labels == image_token_id] = -100
+        batch["labels"] = labels
+
+        return batch
+
+    trainer = SFTTrainer(
+        model=model,
+        args=training_args,
+        data_collator=collate_fn,
+        max_seq_length=custom_args.max_seq_length,
+        dataset_text_field=custom_args.dataset_text_field,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        tokenizer=processor.tokenizer,
+        callbacks=[EfficiencyCallback()],
+    )
+    trainer.train()
+
+
+if __name__ == "__main__":
+    train()
diff --git a/examples/lightning/README.md b/examples/lightning/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..916d79d5fe28b2d6d19fbd84c19ebb4c22b98769
--- /dev/null
+++ b/examples/lightning/README.md
@@ -0,0 +1,21 @@
+# Liger-Kernel Example with Lightning Trainer
+
+## How to Run
+```bash
+pip install -r requirements.txt
+
+# For single L40 48GB GPU
+python training.py --model Qwen/Qwen2-0.5B-Instruct --num_gpu 1 --max_length 1024
+
+# For 8XA100 40GB
+python training.py --model meta-llama/Meta-Llama-3-8B --strategy deepspeed
+```
+
+**Notes**
+1. The example uses the Llama3 model, which requires a community license agreement and HuggingFace Hub login. If you want to use Llama3 in this example, please make sure you have done the following:
+    * Agree on the community license agreement https://huggingface.co/meta-llama/Meta-Llama-3-8B
+    * Run `huggingface-cli login` and enter your HuggingFace token
+2. The default hyperparameters and configuration for gemma work on a single L40 48GB GPU, and the config for llama works on a single node with 8xA100 40GB GPUs. For running on devices with less GPU RAM, consider reducing the per-GPU batch size and/or enabling `CPUOffload` in FSDP.
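+
+The script also supports an FSDP strategy through the flags defined in `training.py`; for example (a sketch, subject to available GPU memory):
+
+```bash
+python training.py --model meta-llama/Meta-Llama-3-8B --strategy fsdp --num_gpu 8
+```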
+ + + \ No newline at end of file diff --git a/examples/lightning/requirements.txt b/examples/lightning/requirements.txt new file mode 100755 index 0000000000000000000000000000000000000000..dbba38e0e05e055aaf7acb8a83d7e8b3bad9f614 --- /dev/null +++ b/examples/lightning/requirements.txt @@ -0,0 +1,8 @@ +lightning +transformers +trl +liger-kernel +torch +triton +deepspeed +tf-keras \ No newline at end of file diff --git a/examples/lightning/training.py b/examples/lightning/training.py new file mode 100755 index 0000000000000000000000000000000000000000..ab11648783f5ae7c3867dfa097ac5880549e24cb --- /dev/null +++ b/examples/lightning/training.py @@ -0,0 +1,281 @@ +import argparse +import math +import os + +from dataclasses import _MISSING_TYPE +from dataclasses import dataclass + +import datasets +import lightning.pytorch as pl +import torch +import transformers + +from lightning.pytorch.strategies import DeepSpeedStrategy +from lightning.pytorch.strategies import FSDPStrategy +from torch.distributed.fsdp import BackwardPrefetch +from torch.distributed.fsdp import MixedPrecision +from torch.utils.data import DataLoader +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer +from trl import DataCollatorForCompletionOnlyLM + +from liger_kernel.transformers import AutoLigerKernelForCausalLM +from liger_kernel.utils import infer_device + +_RETAIN_COLUMNS = {"input_ids", "attention_mask", "labels"} +QUESTION = "" +CHOICES = "" + + +@dataclass +class Args: + model: str = "Qwen/Qwen2-0.5B-Instruct" + data: str = "cais/mmlu" + output_dir: str = "mmlu_finetuning" + max_length: int = 2048 + # for llam3 8B model, deepspeed will OOM with 16 on 8XA100 80G and 8 will OOM with 8XA100 40G + batch_size: int = 4 + lr: float = 6e-6 + weight_decay: float = 0.05 + warmup_ratio: float = 0.1 + seed: int = 42 + strategy: str = "auto" + num_gpu: int = None + + +def warmup_cosine_schedule(warmup_steps, total_steps, min_lr=0): + def lr_lambda(current_step): + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine annealing + progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps)) + return max(min_lr, 0.5 * (1 + math.cos(math.pi * progress))) + + return lr_lambda + + +def parse_args() -> Args: + parser = argparse.ArgumentParser() + for k, v in Args.__dataclass_fields__.items(): + parser.add_argument(f"--{k}", type=v.type, default=v.default) + parsed = parser.parse_args() + return Args(**{k: v for k, v in vars(parsed).items() if not isinstance(v, _MISSING_TYPE)}) + + +class LanguageModel(pl.LightningModule): + def __init__(self, args: Args, tokenizer): + super().__init__() + self.args = args + self.tokenizer = tokenizer + self.model = None + + def configure_model(self): + # https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#speed-up-model-initialization + if self.model is not None: + return + self.model = AutoLigerKernelForCausalLM.from_pretrained( + self.args.model, use_cache=False, ignore_mismatched_sizes=True + ) + if self.args.strategy == "deepspeed": + self.model.train() + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask, labels=None, **kwargs): + return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs) + + def training_step(self, batch): + outputs = self.model( + input_ids=batch["input_ids"], + 
attention_mask=batch["attention_mask"], + labels=batch["labels"], + ) + loss = outputs.loss + self.log_dict( + {"train_loss": loss}, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + rank_zero_only=True, + sync_dist=False, + ) + return loss + + def validation_step(self, batch): + outputs = self.model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"], + ) + loss = outputs.loss + self.log_dict( + {"val_loss": outputs.loss}, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + rank_zero_only=True, + sync_dist=True, + ) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.parameters(), + lr=self.args.lr, + weight_decay=self.args.weight_decay, + fused=True, + ) + lr_lambda = warmup_cosine_schedule( + warmup_steps=self.trainer.estimated_stepping_batches * self.args.warmup_ratio, + total_steps=self.trainer.estimated_stepping_batches, + min_lr=0, + ) + lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + return { + "optimizer": optimizer, + "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"}, + } + + +class DataModule(pl.LightningDataModule): + def __init__(self, tokenizer, args: Args): + super().__init__() + self.args = args + self.tokenizer = tokenizer + self.response_template_str = " " + response_prompt = tokenizer.encode(f"{self.response_template_str}", add_special_tokens=False) + self.collator = DataCollatorForCompletionOnlyLM( + tokenizer=tokenizer, + response_template=response_prompt, + pad_to_multiple_of=16, + ) + + def formatting_func(self, example): + output_texts = [] + for i in range(len(example["question"])): + choices = "" + for j in range(len(example["choices"][i])): + choices += f"{j + 1}. {example['choices'][i][j]}; " + s = "Below is a question and multiple choice answers, choices separated by a semicolon. Please select the best answer for the question. 
" + s += f"{QUESTION}{example['question'][i]} " + s += f"{CHOICES}{choices} " + s += f"{self.response_template_str}{example['answer'][i]}" + output_texts.append(s) + return output_texts + + def tokenize(self, example): + outputs = self.tokenizer( + self.formatting_func(example), + truncation=True, + padding=False, + max_length=self.args.max_length, + ) + return { + "input_ids": outputs["input_ids"], + "attention_mask": outputs["attention_mask"], + } + + def setup(self, stage) -> None: + dataset = datasets.load_dataset(self.args.data, "auxiliary_train") + flattened_data = [ + { + "answer": x["train"]["answer"], + "choices": x["train"]["choices"], + "question": x["train"]["question"], + "subject": x["train"]["subject"], + } + for x in dataset["train"] + ] + dataset = datasets.Dataset.from_list(flattened_data) + dataset = dataset.train_test_split(test_size=4096, seed=self.args.seed) + train_dataset, val_dataset = dataset["train"], dataset["test"] + self.train_dataset = train_dataset.map( + self.tokenize, + remove_columns=list(set(train_dataset.column_names) - _RETAIN_COLUMNS), + batched=True, + batch_size=1, + num_proc=4, + ) + self.val_dataset = val_dataset.map( + self.tokenize, + remove_columns=list(set(val_dataset.column_names) - _RETAIN_COLUMNS), + batched=True, + batch_size=1, + num_proc=4, + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.args.batch_size, + collate_fn=self.collator, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.args.batch_size, + collate_fn=self.collator, + ) + + +def train(): + args = parse_args() + pl.seed_everything(args.seed) + os.makedirs(args.output_dir, exist_ok=True) + + if "Meta-Llama-3-8B" in args.model: + layers = {LlamaDecoderLayer} + elif "Qwen2" in args.model: + layers = {Qwen2DecoderLayer} + else: + layers = {} + raise Warning(f"Unimplemented layer wrap policy for {args.model} in this example") + + if args.strategy == "fsdp": + strategy = FSDPStrategy( + auto_wrap_policy=layers, + sharding_strategy="FULL_SHARD", + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + sync_module_states=True, + activation_checkpointing_policy=layers, + mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16), + forward_prefetch=True, + ) + precision = None + elif args.strategy == "deepspeed": + strategy = DeepSpeedStrategy(stage=3) + precision = "bf16-mixed" + elif args.strategy == "ddp": + strategy = "ddp" + precision = "bf16-true" + else: + strategy = "auto" + precision = "bf16-true" + + device = infer_device() + trainer = pl.Trainer( + accelerator=device, + strategy=strategy, + devices=(getattr(torch, device).device_count() if args.num_gpu is None else args.num_gpu), + default_root_dir=args.output_dir, + log_every_n_steps=1, + max_epochs=1, + precision=precision, + ) + + tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, padding_side="left", truncation_side="left") + tokenizer.pad_token = tokenizer.eos_token + data_module = DataModule( + tokenizer=tokenizer, + args=args, + ) + model = LanguageModel(args=args, tokenizer=tokenizer) + trainer.fit(model, datamodule=data_module) + + +if __name__ == "__main__": + train() diff --git a/examples/medusa/README.md b/examples/medusa/README.md new file mode 100755 index 0000000000000000000000000000000000000000..6d0ba99eb3c6cf39b5e5a85ba48f38854f409780 --- /dev/null +++ b/examples/medusa/README.md @@ -0,0 +1,72 @@ +# Liger-Kernel Example with Medusa + +Medusa is a simple framework that democratizes 
the acceleration techniques for LLM generation with multiple decoding heads. [[repo](https://github.com/FasterDecoding/Medusa)], [[paper](https://arxiv.org/abs/2401.10774)]
+
+During training, Medusa requires adding \(k\) decoding heads to the hidden states right before the regular LM head \(h_t\). The \(k\)-th head is used to predict the token in the \((t + k + 1)\)-th position of the next tokens (the original language model head is used to predict the \((t + 1)\)-th position).
+
+The Liger fused CE kernel is highly effective in this scenario, eliminating the need to materialize logits for each head, which usually consumes a large volume of memory due to the extensive vocabulary size (e.g., for LLaMA-3, the vocabulary size is 128k). The introduction of multiple heads can easily lead to OOM (Out of Memory) issues. However, thanks to the efficient Liger fused CE, which calculates the gradient in place and doesn't materialize the logits, we have observed very effective results. This efficiency opens up more opportunities for multi-token prediction research and development.
+
+
+# Instructions to Run the Training Script
+
+```
+git clone git@github.com:linkedin/Liger-Kernel.git
+cd {PATH_TO_Liger-Kernel}/Liger-Kernel/
+pip install -e .
+cd {PATH_TO_Liger-Kernel}/Liger-Kernel/examples/medusa
+pip install -r requirements.txt
+sh scripts/llama3_8b_medusa.sh
+```
+
+**Notes**
+1. This example uses an optional `use_liger` flag. If true, it applies a monkey patch to use the Liger kernel with the Medusa heads.
+2. The example uses the Llama3 model, which requires a community license agreement and HuggingFace Hub login. If you want to use Llama3 in this example, please make sure you have done the following:
+    * Agree on the community license agreement https://huggingface.co/meta-llama/Meta-Llama-3-8B
+    * Run `huggingface-cli login` and enter your HuggingFace token
+3. The default hyperparameters and configurations work on a single node with 8xA100 GPUs. For running on devices with less GPU RAM, consider reducing the per-GPU batch size and/or enabling `CPUOffload` in FSDP.
+4. We use a smaller sample of the ShareGPT data primarily to benchmark performance. To work effectively, the example requires hyperparameter tuning and careful dataset selection, ensuring the dataset has the same distribution as the LLaMA pretraining data. Contributions to enhance the example code are welcome.
+
+
+# Memory Profiling Result
+
+> **Note:**
+> 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 6, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
+
+## Stage1
+
+Stage1 refers to Medusa-1, where the backbone model is frozen and only the weights of the Medusa heads are updated.
+
+```
+# Setting this flag to True in llama3_8b_medusa.sh enables Stage1
+--medusa_only_heads True
+```
+
+### num_head = 3
+
+![Memory](./docs/images/Memory_Stage1_num_head_3.png)
+![Throughput](./docs/images/Throughput_Stage1_num_head_3.png)
+
+### num_head = 5
+
+![Memory](./docs/images/Memory_Stage1_num_head_5.png)
+![Throughput](./docs/images/Throughput_Stage1_num_head_5.png)
+
+## Stage2
+
+```
+# Setting this flag to False in llama3_8b_medusa.sh enables Stage2
+--medusa_only_heads False
+```
+
+Stage2 refers to Medusa-2, where all the model weights are updated, including the backbone model and the Medusa heads.
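+
+For a rough sense of the scale of the savings shown below (an illustrative estimate, using the benchmark batch size above and an assumed sequence length of 512): materializing bf16 logits for one head costs about 6 x 512 x 128,256 x 2 bytes ≈ 0.8 GB, so five extra heads would add roughly 4 GB of activations (before gradients) that the fused CE kernel never allocates.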
+ +### num_head = 3 + +![Memory](./docs/images/Memory_Stage2_num_head_3.png) +![Throughput](./docs/images/Throughput_Stage2_num_head_3.png) + +### num_head = 5 + +![Memory](./docs/images/Memory_Stage2_num_head_5.png) +![Throughput](./docs/images/Throughput_Stage2_num_head_5.png) + diff --git a/examples/medusa/callback.py b/examples/medusa/callback.py new file mode 100755 index 0000000000000000000000000000000000000000..673243b770a871b42d3a99c22e0165e0a5c3b14e --- /dev/null +++ b/examples/medusa/callback.py @@ -0,0 +1,386 @@ +import os +import time + +from dataclasses import dataclass + +import torch +import transformers + +from accelerate.utils.constants import FSDP_SHARDING_STRATEGY +from transformers import TrainerControl +from transformers import TrainerState +from transformers import TrainingArguments + +from liger_kernel.utils import infer_device + +# https://simple.wikipedia.org/wiki/Byte +# For memory, we use binary system +M_BIN_UNIT = 2**20 +# For metrics (tflops), we use decimal system +T_DEC_UNIT = 10**12 + + +def round_to_n_decimal(x, n): + return round(x, n) + + +@dataclass +class Precision: + """ + Precision is a dataclass to store the number of decimal points for each metric. + """ + + n_decimal_time: int + n_decimal_memory: int + n_decimal_TPS: int + n_decimal_MFU: int + + +@dataclass +class State: + """ + State is a dataclass to store the internal state of the efficiency callback. + """ + + n_warmup_steps: int = 0 + total_peak_memory_allocated: float = float("-inf") + total_peak_memory_reserved: float = float("-inf") + + step_start_time: float = 0.0 + elapsed_time: float = 0.0 + + elapsed_step: int = 0 + + step_start_tokens_seen: int = 0 + elapsed_tokens_seen: int = 0 + + step_start_flos: float = 0.0 + elapsed_flos: float = 0.0 + + global_start_step: int = 0 + + +@dataclass +class Time: + """ + Time is a dataclass to store the time-related metrics. + """ + + step: int = 0 + step_time_sec: float = 0.0 + avg_step_time_sec: float = 0.0 + time_to_completion_sec: float = 0.0 + estimated_total_time_sec: float = 0.0 + + +@dataclass +class Memory: + """ + Memory is a dataclass to store the memory-related metrics. + """ + + step_peak_memory_allocated_MB: float = 0.0 + total_peak_memory_allocated_MB: float = 0.0 + + +@dataclass +class TPS: + """ + TPS is a dataclass to store the tokens per second metrics. + """ + + step_tokens_per_second: float = 0.0 + avg_tokens_per_second: float = 0.0 + + +@dataclass +class MFU: + """ + MFU is a dataclass to store the MFU metrics. + """ + + step_MFU: float = 0.0 + avg_MFU: float = 0.0 + + +class EfficiencyCallback(transformers.TrainerCallback): + """ + EfficiencyCallback is a callback to track the efficiency of the training process. + The tracked stats include: step time, memory, throughput, and MFU. + + It requires including `--include_num_input_tokens_seen` and `logging_steps=1` in the training arguments. 
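+
+    MFU follows the definition in the PaLM paper: achieved FLOPs per second
+    divided by the device's peak FLOPs per second (see the derivation in
+    `on_step_end` below).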
+ + Args: + n_warmup_steps: number of warmup steps + The stats in the first n_warmup_steps will not be added into the aggregated stats + This is because the first few steps might take longer due to jit compliation and other initialization overheads + n_decimal_time: number of decimal points for time + n_decimal_memory: number of decimal points for memory + n_decimal_TPS: number of decimal points for TPS + n_decimal_MFU: number of decimal points for MFU in percentage + """ + + def __init__( + self, + n_warmup_steps=2, + n_decimal_time=2, + n_decimal_memory=2, + n_decimal_TPS=2, + n_decimal_MFU=4, + ): + self.state = State( + n_warmup_steps, + ) + + self.precision = Precision( + n_decimal_time, + n_decimal_memory, + n_decimal_TPS, + n_decimal_MFU, + ) + + self.time = Time() + self.memory = Memory() + self.tps = TPS() + self.mfu = MFU() + self.device = infer_device() + + def on_init_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """ + Event called at the end of the initialization of the [`Trainer`]. + """ + if not args.include_num_input_tokens_seen: + raise Exception( + 'Please pass training argument "--include_num_input_tokens_seen" to track tokens per second' + ) + if args.logging_steps != 1: + raise Exception("Please set logging_steps=1 to track the efficiency metrics accurately") + + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # if loaded from checkpoints, global_start_step is not 1 but state.global_step + self.state.global_start_step = state.global_step + + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + logs: dict[str, float], + **kwargs, + ): + if state.global_step < (self.state.global_start_step + self.state.n_warmup_steps): + return + else: + # spread self.time, self.memory, self.tps, self.mfu to logs + # logs.update(self.time.__dict__) + logs.update(self.memory.__dict__) + logs.update(self.tps.__dict__) + # logs.update(self.mfu.__dict__) + + def on_step_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """ + Event called at the beginning of a training step. If using gradient accumulation, one training step might take + several inputs. 
+ """ + # memory + getattr(torch, self.device).reset_peak_memory_stats() + + # time + self.state.step_start_time = time.perf_counter() + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if state.global_step < (self.state.global_start_step + self.state.n_warmup_steps): + # The end the current step_start_tokens_seen and step_start_flos are the start of next iteration + + # tokens + self.state.step_start_tokens_seen = state.num_input_tokens_seen + # flos + self.state.step_start_flos = state.total_flos + return + + # time + current_time = time.perf_counter() + step_time = current_time - self.state.step_start_time + self.state.elapsed_time += step_time + + # step + global_step = state.global_step + self.state.elapsed_step += 1 + avg_step_time = self.state.elapsed_time / self.state.elapsed_step + + self.time.step = global_step + self.time.step_time_sec = round_to_n_decimal(step_time, self.precision.n_decimal_time) + self.time.avg_step_time_sec = round_to_n_decimal(avg_step_time, self.precision.n_decimal_time) + self.time.time_to_completion_sec = round_to_n_decimal( + avg_step_time * (state.max_steps - global_step), + self.precision.n_decimal_time, + ) + self.time.estimated_total_time_sec = round_to_n_decimal( + avg_step_time * state.max_steps, self.precision.n_decimal_time + ) + + # memory + step_peak_memory_allocated = getattr(torch, self.device).memory.max_memory_allocated() + step_peak_memory_reserved = getattr(torch, self.device).memory.max_memory_reserved() + + self.memory.step_peak_memory_allocated_MB = round_to_n_decimal( + step_peak_memory_allocated / M_BIN_UNIT, self.precision.n_decimal_memory + ) + self.state.total_peak_memory_allocated = max(self.state.total_peak_memory_allocated, step_peak_memory_allocated) + self.memory.total_peak_memory_allocated_MB = round_to_n_decimal( + self.state.total_peak_memory_allocated / M_BIN_UNIT, + self.precision.n_decimal_memory, + ) + + self.memory.step_peak_memory_reserved_MB = round_to_n_decimal( + step_peak_memory_reserved / M_BIN_UNIT, self.precision.n_decimal_memory + ) + + self.state.total_peak_memory_reserved = max(self.state.total_peak_memory_reserved, step_peak_memory_reserved) + + self.memory.total_peak_memory_reserved_MB = round_to_n_decimal( + self.state.total_peak_memory_reserved / M_BIN_UNIT, + self.precision.n_decimal_memory, + ) + + # tokens + step_tokens_seen = state.num_input_tokens_seen - self.state.step_start_tokens_seen + + self.state.elapsed_tokens_seen += step_tokens_seen + + self.tps.step_tokens_per_second = round_to_n_decimal( + step_tokens_seen / step_time, + self.precision.n_decimal_TPS, + ) + + self.tps.avg_tokens_per_second = round_to_n_decimal( + self.state.elapsed_tokens_seen / self.state.elapsed_time, + self.precision.n_decimal_TPS, + ) + + # flos + step_flos = state.total_flos - self.state.step_start_flos + self.state.elapsed_flos += step_flos + + # MFU + # 1. Definition + # + # MFU is defined as (achieved TPS) / (theoretical maximum TPS) = (achieved floating point operations per sec) / (theoretical maximum floating point operations per sec) + # Crucially, the "theoretical maximum" throughput only accounts for the required operations to compute the forward+backward passes, and not rematerialization. 
MFU therefore allows fair comparisons
+        # between training runs on different systems, as the numerator is simply the observed tokens-per-second, and the denominator is only dependent on the model architecture and published maximum FLOPs for a given system.
+        # Ref: https://arxiv.org/pdf/2204.02311
+        #
+        # 2. Implementation in huggingface
+        #
+        # current_flos = 6 * estimate_tokens(input_dict) * num_parameters()
+        # total_flos = sum(current_flos) # across all GPUs
+        # Ref: https://github.com/huggingface/transformers/blob/616bb11d487aabc231bb230b245c42214ea4b254/src/transformers/modeling_utils.py#L1196
+        #
+        # 3. Derive MFU on rank 0
+        #
+        # rank_0_flos = total_flos / n_gpus = measured_flos / effective_n_gpus
+        # rank_0_MFU = rank_0_flos / step_time
+        #
+        # For FSDP, num_parameters() is (1 / n_gpus) of the total parameters. So, the effective_n_gpus = 1
+        # For HSDP, num_parameters() is (1 / local_world_size) of the total parameters. So, the effective_n_gpus = n_nodes
+        # For no sharding and zero-2, num_parameters() is the total parameters. So, the effective_n_gpus = n_gpus
+
+        num_gpus = EfficiencyCallback._get_effective_num_gpus()
+        step_achieved_tflops = step_flos / step_time / num_gpus / T_DEC_UNIT
+
+        avg_achieved_tflops = self.state.elapsed_flos / self.state.elapsed_time / num_gpus / T_DEC_UNIT
+
+        precision_bits = 16 if args.bf16 or args.fp16 else 32
+        gpu_peak_tflops = EfficiencyCallback._get_gpu_peak_tflops(precision_bits)
+
+        self.mfu.step_MFU = round_to_n_decimal(step_achieved_tflops / gpu_peak_tflops, self.precision.n_decimal_MFU)
+
+        self.mfu.avg_MFU = round_to_n_decimal(avg_achieved_tflops / gpu_peak_tflops, self.precision.n_decimal_MFU)
+
+        # The step_start_tokens_seen and step_start_flos at the end of this step are the starting values for the next iteration
+
+        # tokens
+        self.state.step_start_tokens_seen = state.num_input_tokens_seen
+        # flos
+        self.state.step_start_flos = state.total_flos
+
+    @staticmethod
+    def _get_effective_num_gpus():
+        # Calculate the number of effective GPUs for the total FLOPs in order to calculate the per-GPU FLOPs
+        world_size = int(os.environ.get("WORLD_SIZE", "1"))
+
+        if transformers.utils.strtobool(os.environ.get("ACCELERATE_USE_FSDP", "false")):
+            sharding_strategy = os.environ.get("FSDP_SHARDING_STRATEGY", FSDP_SHARDING_STRATEGY[0]).upper()
+
+            # Either specified as string or enum number
+            if sharding_strategy in {
+                "FULL_SHARD",
+                str(FSDP_SHARDING_STRATEGY.index("FULL_SHARD") + 1),
+            }:
+                return 1
+
+            elif sharding_strategy in {
+                "HYBRID_SHARD",
+                str(FSDP_SHARDING_STRATEGY.index("HYBRID_SHARD") + 1),
+            }:
+                return world_size // int(os.environ.get("LOCAL_WORLD_SIZE", 1))
+            else:
+                return world_size
+
+        assert world_size != 0, (
+            "WORLD_SIZE should be set to a positive integer. For single GPU training, please explicitly set WORLD_SIZE=1."
+        )
+
+        # TODO: add deepspeed support
+        return world_size
+
+    @staticmethod
+    def _get_gpu_peak_tflops(precision_bits: int = 16):
+        if precision_bits not in {16, 32}:
+            raise Exception(f"Precision bits {precision_bits} is not supported")
+
+        device_name = getattr(torch, infer_device()).get_device_name()
+
+        if "A100" in device_name:
+            # data from https://www.nvidia.com/en-us/data-center/a100/
+            return 312 if precision_bits == 16 else 156
+        elif "H100" in device_name:
+            # data from https://www.nvidia.com/en-us/data-center/h100/
+            # NOTE: Specifications are one-half lower without sparsity.
+ if "NVL" in device_name: + return 1979 if precision_bits == 16 else 989 + elif "PCIe" in device_name: + return 756 if precision_bits == 16 else 378 + else: # for SXM and other variants + return 989 if precision_bits == 16 else 494 + elif "V100" in device_name: + if "NVL" in device_name: + return 125 + else: + return 112 + return None diff --git a/examples/medusa/docs/images/Memory_Stage1_num_head_3.png b/examples/medusa/docs/images/Memory_Stage1_num_head_3.png new file mode 100755 index 0000000000000000000000000000000000000000..34f044faf2fe47d1cf972bb09e9472f0dae877d7 Binary files /dev/null and b/examples/medusa/docs/images/Memory_Stage1_num_head_3.png differ diff --git a/examples/medusa/docs/images/Memory_Stage1_num_head_5.png b/examples/medusa/docs/images/Memory_Stage1_num_head_5.png new file mode 100755 index 0000000000000000000000000000000000000000..61d66f6b0465c4b4cf83013e69676eb0ed4dd128 Binary files /dev/null and b/examples/medusa/docs/images/Memory_Stage1_num_head_5.png differ diff --git a/examples/medusa/docs/images/Memory_Stage2_num_head_3.png b/examples/medusa/docs/images/Memory_Stage2_num_head_3.png new file mode 100755 index 0000000000000000000000000000000000000000..3b860a887ae3a075c905ae04b444d97af199a9df Binary files /dev/null and b/examples/medusa/docs/images/Memory_Stage2_num_head_3.png differ diff --git a/examples/medusa/docs/images/Memory_Stage2_num_head_5.png b/examples/medusa/docs/images/Memory_Stage2_num_head_5.png new file mode 100755 index 0000000000000000000000000000000000000000..f7920168225fa4351cf190cd71c78844ae4ff77b Binary files /dev/null and b/examples/medusa/docs/images/Memory_Stage2_num_head_5.png differ diff --git a/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png b/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png new file mode 100755 index 0000000000000000000000000000000000000000..68d682b6b8600afe2ec618db91bc47340826d438 Binary files /dev/null and b/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png differ diff --git a/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png b/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png new file mode 100755 index 0000000000000000000000000000000000000000..77058dc0bb9a17f51914054e528f072f7160eae5 Binary files /dev/null and b/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png differ diff --git a/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png b/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png new file mode 100755 index 0000000000000000000000000000000000000000..2595387fd68202192af50ee97afbc464e5721bc9 Binary files /dev/null and b/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png differ diff --git a/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png b/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png new file mode 100755 index 0000000000000000000000000000000000000000..9f9c2543335857485abfba79eee71b2160e2a902 Binary files /dev/null and b/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png differ diff --git a/examples/medusa/fsdp/acc-fsdp.conf b/examples/medusa/fsdp/acc-fsdp.conf new file mode 100755 index 0000000000000000000000000000000000000000..2ed641fe5a0e9e66f20f1a400b25d264f399c702 --- /dev/null +++ b/examples/medusa/fsdp/acc-fsdp.conf @@ -0,0 +1,24 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'yes' +fsdp_config: + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: NO_PREFETCH + fsdp_cpu_ram_efficient_loading: true + 
+  fsdp_forward_prefetch: false
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: HYBRID_SHARD
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: true
+main_training_function: main
+mixed_precision: bf16
+rdzv_backend: static
+same_network: true
+num_machines: 1
+num_processes: 1
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/examples/medusa/medusa_util.py b/examples/medusa/medusa_util.py
new file mode 100755
index 0000000000000000000000000000000000000000..7c66e0e088adabc3ff32c187e0c56d44ae10868a
--- /dev/null
+++ b/examples/medusa/medusa_util.py
@@ -0,0 +1,267 @@
+import types
+
+from typing import List
+from typing import Optional
+
+import torch
+
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers import PretrainedConfig
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+
+
+class MedusaConfig(PretrainedConfig):
+    """
+    Configuration class for the Medusa model.
+
+    Args:
+        medusa_num_heads (int, optional): Number of heads for the Medusa layer. Default is 4.
+        medusa_num_layers (int, optional): Number of Medusa layers. Default is 1.
+        base_model_name_or_path (str, optional): The name or path of the base model. Default is "/shared/public/models/Meta-Llama-3-8B".
+        **kwargs: Additional keyword arguments to be passed to the parent class constructor.
+    """
+
+    def __init__(
+        self,
+        medusa_num_heads=4,
+        medusa_num_layers=1,
+        base_model_name_or_path="/shared/public/models/Meta-Llama-3-8B",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.medusa_num_heads = medusa_num_heads
+        self.medusa_num_layers = medusa_num_layers
+        self.base_model_name_or_path = base_model_name_or_path
+
+
+class ResBlock(nn.Module):
+    """
+    A Residual Block module.
+
+    This module performs a linear transformation followed by a SiLU activation,
+    and then adds the result to the original input, creating a residual connection.
+
+    Args:
+        hidden_size (int): The size of the hidden layers in the block.
+    """
+
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.linear = nn.Linear(hidden_size, hidden_size)
+        # Initialize as an identity mapping
+        nn.init.zeros_(self.linear.weight)
+        # Use SiLU activation to keep consistent with the Llama model
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        """
+        Forward pass of the ResBlock.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output after the residual connection and activation.
+        """
+        return x + self.act(self.linear(x))
+
+
+def calculate_loss_contribution(
+    loss_i,
+    i,
+    medusa_only_heads,
+    medusa_decay_coefficient,
+    medusa_heads_coefficient,
+    medusa_scheduler_coefficient,
+):
+    if i == 0:
+        return loss_i if not medusa_only_heads else 0
+    else:
+        return loss_i * medusa_decay_coefficient**i * medusa_heads_coefficient * medusa_scheduler_coefficient
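To see how this weighting plays out, here is a small sketch with hypothetical per-head loss values; the 0.8/0.2 coefficients mirror the example launch script further below, and `calculate_loss_contribution` is assumed to be in scope:

```python
# Index 0 is the original lm_head loss, indices 1..3 are Medusa head losses.
losses = [1.20, 2.50, 3.10, 3.60]

total = 0.0
for i, loss_i in enumerate(losses):
    total += calculate_loss_contribution(
        loss_i,
        i,
        medusa_only_heads=True,            # drops the lm_head term (returns 0 for i == 0)
        medusa_decay_coefficient=0.8,      # later heads contribute geometrically less
        medusa_heads_coefficient=0.2,
        medusa_scheduler_coefficient=1.0,
    )
# i=1: 2.50 * 0.8**1 * 0.2 = 0.400
# i=2: 3.10 * 0.8**2 * 0.2 = 0.3968
# i=3: 3.60 * 0.8**3 * 0.2 = 0.36864
# total ~= 1.165
```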
+def add_medusa_heads(
+    model,
+    medusa_num_heads=4,
+    medusa_num_layers=0,
+    medusa_return: bool = False,
+    medusa_only_heads: bool = False,
+    with_liger=True,
+):
+    """
+    Args:
+        model (nn.Module): The base language model to be used.
+        medusa_num_heads (int, optional): The number of additional tokens to predict. Defaults to 4.
+        medusa_num_layers (int, optional): The number of ResBlock layers for each Medusa head. Defaults to 0.
+        medusa_return (bool, optional): If True, returns the Medusa logits; otherwise, the forward pass will use the `lm_head`. Defaults to False.
+        medusa_only_heads (bool, optional): If True, only the Medusa head weights will be updated during fine-tuning; otherwise, the entire model's weights will be updated. Defaults to False.
+        with_liger (bool, optional): If True, applies the Liger fused linear cross-entropy loss. Defaults to True.
+    """
+    hidden_size = model.lm_head.weight.shape[-1]
+    vocab_size = model.lm_head.weight.shape[0]
+    model.config.medusa_num_layers = medusa_num_layers
+    model.config.medusa_num_heads = medusa_num_heads
+    model.medusa_num_heads = medusa_num_heads
+    # Create a list of Medusa heads
+    model.medusa_head = nn.ModuleList(
+        [
+            nn.Sequential(
+                *([ResBlock(hidden_size) for _ in range(medusa_num_layers)]),
+                nn.Linear(hidden_size, vocab_size, bias=False),
+            )
+            for _ in range(medusa_num_heads)
+        ]
+    )
+
+    # Ensure medusa_head's dtype and device align with the base_model
+    model.medusa_head.to(model.dtype).to(model.device)
+
+    for i in range(medusa_num_heads):
+        # Initialize the weights of each medusa_head using the base model's weights
+        model.medusa_head[i][-1].weight.data[:] = model.lm_head.weight.data[:]
+    # Log the model summary
+    print(model)
+    model.old_forward = model.forward
+
+    def forward(
+        model,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        """Forward pass of the MedusaModel.
+        Returns:
+            torch.Tensor: A tensor containing predictions from all Medusa heads.
+            (Optional) Original predictions from the base model's LM head.
+ """ + loss = 0 + medusa_logits = None + # LOG.debug("medusa_return: %s", medusa_return) + if not medusa_return: + return model.old_forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # Pass input through the base model + if medusa_only_heads: + with torch.no_grad(): + outputs = model.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + # The lm_head will be frozen as well, so it's within the context of torch.no_grad() + if not with_liger: + medusa_logits = [model.lm_head(hidden_states)] + else: + outputs = model.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + if not with_liger: + medusa_logits = [model.lm_head(hidden_states)] + + if not with_liger: + for i in range(model.medusa_num_heads): + medusa_logits.append(model.medusa_head[i](hidden_states)) + medusa_logits = torch.stack(medusa_logits, dim=0) + + if model.training: + # Fix all the coefficients to 1 for now + medusa_scheduler_coefficient = 1 + medusa_heads_coefficient = 1 + medusa_decay_coefficient = 1 + loss = 0 + + if with_liger: + lce = LigerFusedLinearCrossEntropyLoss() + for i in range(model.medusa_num_heads + 1): + shift_hidden_states = ( + hidden_states[..., : -(1 + i), :].contiguous().view(-1, model.config.hidden_size) + ) + shift_labels = labels[..., (1 + i) :].contiguous().view(-1) + + weight = model.lm_head.weight if i == 0 else model.medusa_head[i - 1][-1].weight + loss_i = lce(weight, shift_hidden_states, shift_labels) + + loss += calculate_loss_contribution( + loss_i, + i, + medusa_only_heads, + medusa_decay_coefficient, + medusa_heads_coefficient, + medusa_scheduler_coefficient, + ) + else: + loss_fct = CrossEntropyLoss() + for i in range(model.medusa_num_heads + 1): + medusa_logits_i = medusa_logits[i, :, : -(1 + i)].contiguous().view(-1, medusa_logits.shape[-1]) + medusa_logits_i = medusa_logits_i.float() + medusa_labels = labels[..., (1 + i) :].contiguous().view(-1).to(medusa_logits_i.device) + + loss_i = loss_fct(medusa_logits_i, medusa_labels) + + loss += calculate_loss_contribution( + loss_i, + i, + medusa_only_heads, + medusa_decay_coefficient, + medusa_heads_coefficient, + medusa_scheduler_coefficient, + ) + else: + if model.config.pretraining_tp > 1: + raise NotImplementedError + else: + medusa_logits = [model.lm_head(hidden_states)] + for i in range(model.medusa_num_heads): + medusa_logits.append(model.medusa_head[i](hidden_states)) + + return_dict = return_dict if return_dict is not None else model.config.use_return_dict + + if not return_dict: + output = (medusa_logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=medusa_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + 
attentions=outputs.attentions, + ) + + model.forward = types.MethodType(forward, model) diff --git a/examples/medusa/requirements.txt b/examples/medusa/requirements.txt new file mode 100755 index 0000000000000000000000000000000000000000..36ab7fe12b1bf508e8c86fa49b619a346b13ad32 --- /dev/null +++ b/examples/medusa/requirements.txt @@ -0,0 +1,3 @@ +accelerate==1.6.0 +scikit-learn +transformers==4.51.3 diff --git a/examples/medusa/scripts/llama3_8b_medusa.sh b/examples/medusa/scripts/llama3_8b_medusa.sh new file mode 100755 index 0000000000000000000000000000000000000000..3426c0fdcc891d57e010b5f96d8a05f226c4510b --- /dev/null +++ b/examples/medusa/scripts/llama3_8b_medusa.sh @@ -0,0 +1,53 @@ +#!/bin/sh + +export GPUS_PER_NODE=$(nvidia-smi --list-gpus | wc -l) +export LOCAL_WORLD_SIZE=$GPUS_PER_NODE +export NUM_NODES=$WORLD_SIZE +export WORLD_SIZE=$((GPUS_PER_NODE * NUM_NODES)) +echo "Starting training... Num nodes: $NUM_NODES, Num workers: $WORLD_SIZE" + +export OUTPUT_DIR="./llama3-8b-medusa-liger" + +export LOCAL_TRAIN_BATCH_SIZE=4 +export GRADIENT_ACCUMULATION_STEPS=1 +export LR=1e-5 + +export MEDUSA_NUM_HEADS=5 +export MEDUSA_NUM_LAYERS=1 +export MEDUSA_HEADS_COEFFICIENT=0.2 +export MEDUSA_DECAY_COEFFICIENT=0.8 +export MEDUSA_SCHEDULER=constant +export MEDUSA_LR_MULTIPLIER=4.0 + +accelerate launch --config_file fsdp/acc-fsdp.conf \ + --num_machines $NUM_NODES \ + --num_processes $WORLD_SIZE \ + train.py \ + --bf16 True \ + --output_dir $OUTPUT_DIR \ + --num_train_epochs 10 \ + --per_device_train_batch_size $LOCAL_TRAIN_BATCH_SIZE \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \ + --eval_strategy "no" \ + --save_strategy "no" \ + --prediction_loss_only \ + --learning_rate $LR \ + --weight_decay 0. \ + --warmup_ratio 0.04 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --model_max_length 1024 \ + --gradient_checkpointing True \ + --lazy_preprocess False \ + --report_to none \ + --include_num_input_tokens_seen \ + --medusa_num_heads $MEDUSA_NUM_HEADS \ + --medusa_num_layers $MEDUSA_NUM_LAYERS \ + --medusa_heads_coefficient $MEDUSA_HEADS_COEFFICIENT \ + --medusa_decay_coefficient $MEDUSA_DECAY_COEFFICIENT \ + --medusa_scheduler $MEDUSA_SCHEDULER \ + --medusa_lr_multiplier $MEDUSA_LR_MULTIPLIER \ + --medusa_only_heads False \ + --medusa_return True \ + --use_liger True diff --git a/examples/medusa/train.py b/examples/medusa/train.py new file mode 100755 index 0000000000000000000000000000000000000000..1d321343055cf8256b77e7de3c0646a3da24cbf5 --- /dev/null +++ b/examples/medusa/train.py @@ -0,0 +1,381 @@ +# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright: +# +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+# Adapted from: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
+
+import json
+import os
+import pathlib
+
+from dataclasses import dataclass
+from dataclasses import field
+from typing import Dict
+from typing import Optional
+
+import torch
+import transformers
+
+from callback import EfficiencyCallback
+from medusa_util import add_medusa_heads
+from safetensors.torch import save_file
+from sklearn.model_selection import train_test_split
+from torch.utils.data import Dataset
+from transformers import Trainer
+from transformers.trainer_pt_utils import LabelSmoother
+
+from liger_kernel.transformers import AutoLigerKernelForCausalLM
+
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+
+
+@dataclass
+class ModelArguments:
+    model_name_or_path: Optional[str] = field(default="meta-llama/Meta-Llama-3-8B-Instruct")
+
+
+@dataclass
+class DataArguments:
+    data_path: str = field(
+        default="Aeala/ShareGPT_Vicuna_unfiltered",
+        metadata={"help": "Path to the training data."},
+    )
+    eval_data_path: str = field(default=None, metadata={"help": "Path to the evaluation data."})
+    lazy_preprocess: bool = True
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+    cache_dir: Optional[str] = field(default=None)
+    report_to: Optional[str] = None
+    optim: str = field(default="adamw_torch")
+    model_max_length: int = field(
+        default=2048,
+        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+    )
+    medusa_num_heads: int = field(
+        default=1,
+        metadata={"help": "Number of Medusa heads."},
+    )
+    medusa_num_layers: int = field(
+        default=1,
+        metadata={"help": "Number of layers for each Medusa head."},
+    )
+    medusa_heads_coefficient: float = field(
+        default=1.0,
+        metadata={"help": "Coefficient for the Medusa heads."},
+    )
+    medusa_decay_coefficient: float = field(
+        default=1.0,
+        metadata={"help": "Decay coefficient for the Medusa head losses."},
+    )
+    medusa_scheduler: str = field(
+        default="constant",
+        metadata={"help": "Scheduler for the Medusa heads."},
+    )
+    medusa_lr_multiplier: float = field(
+        default=0.0,
+        metadata={"help": "Learning rate multiplier for the Medusa heads."},
+    )
+    medusa_return: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether the forward pass returns Medusa head predictions. Defaults to False, in which case the regular lm_head is used for single-token prediction."
+        },
+    )
+    medusa_only_heads: bool = field(
+        default=False,
+        metadata={"help": "Whether to train the Medusa heads only. Defaults to False, in which case the whole model is trained."},
+    )
+    use_liger: bool = field(
+        default=False,
+        metadata={"help": "Whether to apply the Liger kernel to the model."},
+    )
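Since all of the Medusa knobs live on this dataclass, the configuration the launch script passes on the command line can also be built directly in Python, e.g. for quick experiments. A minimal sketch, assuming this file is importable as the module `train` (values mirror `scripts/llama3_8b_medusa.sh`; note that instantiating `TrainingArguments` performs device setup, so flags like `bf16` must match the machine):

```python
from train import TrainingArguments  # hypothetical import; requires train.py on the path

args = TrainingArguments(
    output_dir="./llama3-8b-medusa-liger",
    bf16=True,
    num_train_epochs=10,
    per_device_train_batch_size=4,
    learning_rate=1e-5,
    medusa_num_heads=5,
    medusa_num_layers=1,
    medusa_heads_coefficient=0.2,
    medusa_decay_coefficient=0.8,
    medusa_return=True,       # train with the Medusa forward pass
    medusa_only_heads=False,  # update the whole model, not just the heads
    use_liger=True,
)
```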
+local_rank = None
+
+
+def rank0_print(*args):
+    if local_rank == 0:
+        print(*args)
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+    """
+    Save the model's state dictionary to a specified directory.
+
+    Args:
+        trainer (transformers.Trainer): The Hugging Face Trainer object.
+        output_dir (str): The directory where the model state dictionary will be saved.
+    """
+    state_dict = trainer.model.state_dict()
+    if trainer.args.should_save:
+        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+        del state_dict
+        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
+
+
+def preprocess(
+    sources,
+    tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+    """
+    Preprocesses conversation data and tokenizes it for model input.
+
+    Args:
+        sources: A list of conversation sources.
+        tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for tokenization.
+
+    Returns:
+        Dict: A dictionary containing tokenized inputs, labels, and attention mask.
+    """
+
+    # Apply prompt templates
+    conversations = []
+    prompts = []
+    # NOTE: only the first 50 conversations are used
+    for conversation in sources[:50]:
+        tokenizer_compatible_conv = [
+            {
+                "role": "user" if c["from"] == "human" else "assistant",
+                "content": c["value"],
+            }
+            for c in conversation["conversations"]
+        ]
+        prompt = tokenizer.apply_chat_template(tokenizer_compatible_conv, tokenize=False)
+        prompts.append(prompt)
+        conversations.append(tokenizer_compatible_conv)
+
+    # Tokenize conversations
+    encoding = tokenizer(
+        prompts,
+        return_tensors="pt",
+        padding="max_length",
+        truncation=True,
+        return_offsets_mapping=True,
+    )
+    # Set everything to be ignored, except the assistant part
+    targets = torch.full_like(encoding.input_ids, IGNORE_TOKEN_ID)
+    input_ids = encoding.input_ids
+
+    # Mask targets. Only compute loss on the assistant outputs.
+    for conv_index, (conversation, target, prompt) in enumerate(zip(conversations, targets, prompts)):
+        for turn in conversation:
+            if turn["role"] == "assistant":
+                content = turn["content"]
+                # Unfortunate strip() necessary because chat templates are doing the same.
+                start = prompt.index(content.strip())
+                stop = start + len(content.strip())
+                indices = []
+                for tok_index, (tok_start, tok_stop) in enumerate(encoding.offset_mapping[conv_index]):
+                    # Keep tokens whose character span overlaps the assistant span [start, stop)
+                    if tok_start < stop and tok_stop > start:
+                        indices.append(tok_index)
+                target[indices] = encoding.input_ids[conv_index][indices]
+
+    return dict(
+        input_ids=input_ids,
+        labels=targets,
+        attention_mask=input_ids.ne(tokenizer.pad_token_id),
+    )
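Since the overlap test above is easy to get wrong, here is a tokenizer-free toy illustration of the same masking logic (the prompt, span, and offsets are all made up):

```python
# Offsets are (char_start, char_stop) pairs per token; we unmask exactly the
# tokens overlapping the assistant span [start, stop).
prompt = "user: hi assistant: hello there"
content = "hello there"
start = prompt.index(content)  # 20
stop = start + len(content)    # 31

# Hypothetical offsets, as a fast tokenizer might produce them:
offset_mapping = [(0, 4), (4, 8), (8, 19), (19, 26), (26, 31)]

indices = [
    i
    for i, (tok_start, tok_stop) in enumerate(offset_mapping)
    if tok_start < stop and tok_stop > start
]
print(indices)  # [3, 4] -> only the tokens covering "hello there" keep their labels
```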
+ """ + + def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): + super(LazySupervisedDataset, self).__init__() + self.tokenizer = tokenizer + + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.raw_data = raw_data + self.cached_data_dict = {} + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + if i in self.cached_data_dict: + return self.cached_data_dict[i] + + ret = preprocess([self.raw_data[i]], self.tokenizer) + ret = dict( + input_ids=ret["input_ids"][0], + labels=ret["labels"][0], + attention_mask=ret["attention_mask"][0], + ) + self.cached_data_dict[i] = ret + + return ret + + +def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args, test_size=0.05) -> Dict: + """Make dataset and collator for supervised fine-tuning. + + Args: + tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. + data_args: Data arguments. + test_size: evaluation data ratio (default: 0.05) + + Returns: + dict: A dictionary containing train and eval datasets. + """ + dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset + rank0_print("Loading data...") + + # Load the entire dataset + train_json = json.load(open(data_args.data_path, "r")) + + # Perform a train-test split based on test_size + train_data, eval_data = train_test_split(train_json, test_size=test_size, random_state=42) + # Create the train and eval datasets + train_dataset = dataset_cls(train_data, tokenizer=tokenizer) + eval_dataset = dataset_cls(eval_data, tokenizer=tokenizer) + + return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) + + +def train(): + global local_rank + + parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=True, + ) + tokenizer.pad_token = tokenizer.unk_token + tokenizer.pad_token = tokenizer.eos_token + + # Making sure the tokenizer works before loading the model. + print(tokenizer(["This is a test", "secondary"], padding=True)) + print(tokenizer.apply_chat_template([{"role": "user", "content": "This is a test"}])) + + def _model_loader(): + # we use a customized model loader to inject medusa heads to FSDP-wrapped model variables properly. + # see https://github.com/linkedin/Liger-Kernel/issues/309#issuecomment-2455077623 for details. 
+
+        # Load model
+        if training_args.use_liger:
+            model_builder = AutoLigerKernelForCausalLM.from_pretrained
+        else:
+            model_builder = transformers.AutoModelForCausalLM.from_pretrained
+        model = model_builder(
+            model_args.model_name_or_path,
+            cache_dir=training_args.cache_dir,
+            torch_dtype=torch.bfloat16,  # `torch_dtype` is the kwarg recognized by the pinned transformers 4.51
+        )
+
+        # Freeze the base model
+        for param in model.base_model.parameters():
+            param.requires_grad = False
+
+        # Inject Medusa heads
+        add_medusa_heads(
+            model,
+            training_args.medusa_num_heads,
+            training_args.medusa_num_layers,
+            training_args.medusa_return,
+            training_args.medusa_only_heads,
+            training_args.use_liger,
+        )
+        return model
+
+    # Format output dir
+    training_args.output_dir = f"{training_args.output_dir}_medusa_mlp_{model_args.model_name_or_path.split('/')[-1]}_medusa_{training_args.medusa_num_heads}_lr_{training_args.learning_rate}_layers_{training_args.medusa_num_layers}"
+
+    # Load data
+    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
+
+    # Start trainer
+    trainer = Trainer(
+        model_init=_model_loader,
+        tokenizer=tokenizer,
+        args=training_args,
+        callbacks=[EfficiencyCallback()],
+        **data_module,
+    )
+
+    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+        trainer.train(resume_from_checkpoint=True)
+    else:
+        trainer.train()
+
+    if training_args.medusa_return and training_args.medusa_only_heads:
+        # Save only the updated heads without saving the backbone model
+        state_dict = {
+            k.replace("medusa_head.", ""): v.to(torch.bfloat16)
+            for k, v in trainer.accelerator.get_state_dict(trainer.model).items()
+            if "medusa_head" in k
+        }
+
+        # Save Medusa heads
+        if local_rank == 0:
+            save_file(
+                state_dict,
+                os.path.join(training_args.output_dir, "medusa_lm_head.safetensors"),
+            )
+        trainer.accelerator.wait_for_everyone()
+    else:
+        # Save the whole model weights
+        trainer.save_model(training_args.output_dir)
+
+
+if __name__ == "__main__":
+    train()
diff --git a/licenses/LICENSE-Apache-2.0 b/licenses/LICENSE-Apache-2.0
new file mode 100755
index 0000000000000000000000000000000000000000..0328c5ff05074b77adab24051b423e722ac9941c
--- /dev/null
+++ b/licenses/LICENSE-Apache-2.0
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2024-] [Unsloth AI, Daniel Han-Chen & Michael Han-Chen] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/licenses/LICENSE-MIT-AutoAWQ b/licenses/LICENSE-MIT-AutoAWQ new file mode 100755 index 0000000000000000000000000000000000000000..c8de3cf7f0202fc59b57dbdbee9ac936756e4a29 --- /dev/null +++ b/licenses/LICENSE-MIT-AutoAWQ @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 MIT HAN Lab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-MIT-Efficient-Cross-Entropy b/licenses/LICENSE-MIT-Efficient-Cross-Entropy new file mode 100755 index 0000000000000000000000000000000000000000..17736429bcfbc11fa9d9bdf9ca549f5ea5a2c8a4 --- /dev/null +++ b/licenses/LICENSE-MIT-Efficient-Cross-Entropy @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 mgmalek + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-MIT-llmc b/licenses/LICENSE-MIT-llmc new file mode 100755 index 0000000000000000000000000000000000000000..99d8f1f022950f0dc55f01b996d219c122ac2db6 --- /dev/null +++ b/licenses/LICENSE-MIT-llmc @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Andrej Karpathy + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/licenses/LICENSE-MIT-triton b/licenses/LICENSE-MIT-triton new file mode 100755 index 0000000000000000000000000000000000000000..0f3852f090ae2ef4dd2c9734669cb404a2f788da --- /dev/null +++ b/licenses/LICENSE-MIT-triton @@ -0,0 +1,23 @@ +/* +* Copyright 2018-2020 Philippe Tillet +* Copyright 2020-2022 OpenAI +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files +* (the "Software"), to deal in the Software without restriction, +* including without limitation the rights to use, copy, modify, merge, +* publish, distribute, sublicense, and/or sell copies of the Software, +* and to permit persons to whom the Software is furnished to do so, +* subject to the following conditions: +* +* The above copyright notice and this permission notice shall be +* included in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +*/ \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100755 index 0000000000000000000000000000000000000000..9a1ba6607eb0e9a37ffdae5f1b4c43efdd1397c0 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,87 @@ +site_name: Liger-Kernel Docs +site_url: https://linkedin.github.io/Liger-Kernel/ +site_author: LinkedIn +site_description: Efficient Triton Kernels for LLM Training + +theme: + name: material + font: + text: Merriweather Sans + code: Red Hat Mono + features: + - navigation.footer + - toc.follow + - navigation.top + - navigation.sections + palette: + # Dark Mode + - scheme: slate + toggle: + icon: material/weather-sunny + name: Dark mode + primary: green + accent: deep purple + + # Light Mode + - scheme: default + toggle: + icon: material/weather-night + name: Light mode + primary: blue + accent: deep purple + +nav: + - Home: index.md + - Examples: Examples.md + - Getting Started: Getting-Started.md + - High Level APIs: High-Level-APIs.md + - Low Level APIs: Low-Level-APIs.md + - Contributing: contributing.md + - Acknowledgment: acknowledgement.md + - License: license.md + +markdown_extensions: + - attr_list + - toc: + permalink: true + - pymdownx.highlight: + anchor_linenums: true + line_spans: __span + pygments_lang_class: true + - pymdownx.inlinehilite + - pymdownx.snippets + - pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: !!python/name:pymdownx.superfences.fence_code_format + - pymdownx.tabbed: + alternate_style: true + - admonition + - pymdownx.details + +plugins: + - search + - mkdocstrings: + handlers: + python: + paths: [src] + options: + show_root_heading: true + show_source: true + docstring_style: google + docstring_section_style: table + heading_level: 3 + show_signature_annotations: false # Hides type annotations to save space + separate_signature: true # Separates signature from description + + +# Repository +repo_name: linkedin/Liger-Kernel +repo_url: https://github.com/linkedin/Liger-Kernel +edit_uri: edit/main/docs/ + +extra: + social: + - icon: simple/github + link: https://github.com/linkedin/Liger-Kernel diff 
--git a/pyproject.toml b/pyproject.toml new file mode 100755 index 0000000000000000000000000000000000000000..b613ff44d6037857e5c422307d655554060f8d03 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,86 @@ +[build-system] +requires = ["setuptools>=42", "wheel", "setuptools-scm"] +build-backend = "setuptools.build_meta" + +[project] +name = "liger_kernel" +version = "0.7.0" +description = "Efficient Triton kernels for LLM Training" +urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" } +readme = { file = "README.md", content-type = "text/markdown" } +license = { file = "LICENSE" } +dynamic = ["dependencies", "optional-dependencies"] + +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] +include = ["liger_kernel*"] +namespaces = false + +[tool.pytest.ini_options] +pythonpath = ["src", "."] +asyncio_mode = "auto" +log_cli = true +log_cli_level = "INFO" +addopts = [ + "--cov=src/liger_kernel", + "--cov-report=term-missing", + "--cov-report=html", + "--cov-config=pyproject.toml", + "--durations=0" +] +python_files = "test_*.py" +testpaths = ["test/"] + +[tool.coverage.run] +branch = true +parallel = true +source = ["src/liger_kernel"] +# xdist uses subprocesses; "multiprocessing" is a safe concurrency choice +concurrency = ["multiprocessing"] + +[tool.coverage.paths] +liger_kernel = [ + "src/liger_kernel", + "*/site-packages/liger_kernel" +] + +[tool.coverage.report] +omit = ["test/*"] +show_missing = true +skip_covered = false + + +[tool.ruff] +line-length = 120 +target-version = "py310" +respect-gitignore = true +src = ["src"] + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "F", # pyflakes + "I", # isort +] +ignore = ["E501", "B006", "E731", "A002", "E203"] + +exclude = [ + ".git", + "__pycache__", + "benchmark_internal/others", + ".venv", +] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" + +[tool.ruff.lint.isort] +known-first-party = ["liger_kernel"] +force-single-line = true +lines-between-types = 1 diff --git a/setup.py b/setup.py new file mode 100755 index 0000000000000000000000000000000000000000..bf3457eebf0dc9cb1e698c39e55755e186a4d1fb --- /dev/null +++ b/setup.py @@ -0,0 +1,132 @@ +# setup.py + +import subprocess + +from typing import Literal + +from setuptools import setup + + +def get_default_dependencies(): + """Determine the appropriate dependencies based on detected hardware.""" + platform = get_platform() + + if platform in ["cuda", "cpu"]: + return [ + "torch>=2.1.2", + "triton>=2.3.1", + ] + elif platform == "rocm": + return [ + "triton>=3.0.0", + ] + elif platform == "xpu": + return [ + "torch>=2.6.0", + ] + # TODO: Currently, triton-ascend is not compatible with torch 2.7.1. We will upgrade it later. + elif platform == "npu": + return ["torch==2.6.0", "torch_npu==2.6.0", "triton-ascend"] + + +def get_optional_dependencies(): + """Get optional dependency groups.""" + return { + "dev": [ + "transformers>=4.52.0", + "matplotlib>=3.7.2", + "ruff>=0.12.0", + "pytest>=7.1.2", + "pytest-xdist", + "pytest-cov", + "pytest-asyncio", + "pytest-rerunfailures", + "datasets>=2.19.2", + "seaborn", + "mkdocs-material", + "torchvision>=0.20", + "prek>=0.2.28", + ] + } + + +def is_xpu_available(): + """ + Check if Intel XPU is available. + xpu-smi is often missing right now. 
+ """ + try: + subprocess.run(["xpu-smi"], check=True) + return True + except (subprocess.SubprocessError, FileNotFoundError): + pass + + try: + result = subprocess.run("sycl-ls", check=True, capture_output=True, shell=True) + if "level_zero:gpu" in result.stdout.decode(): + return True + except (subprocess.SubprocessError, FileNotFoundError): + pass + + return False + + +def is_ascend_available() -> bool: + """Best-effort Ascend detection. + + Checks for common Ascend environment variables and a possible `npu-smi` + utility if present. + """ + try: + subprocess.run(["npu-smi", "info"], check=True) + return True + except (subprocess.SubprocessError, FileNotFoundError): + pass + return False + + +def get_platform() -> Literal["cuda", "rocm", "cpu", "xpu", "npu"]: + """ + Detect whether the system has NVIDIA or AMD GPU without torch dependency. + """ + # Try nvidia-smi first + try: + subprocess.run(["nvidia-smi"], check=True) + print("NVIDIA GPU detected") + return "cuda" + except (subprocess.SubprocessError, FileNotFoundError): + # If nvidia-smi fails, check for ROCm + try: + subprocess.run(["rocm-smi"], check=True) + print("ROCm GPU detected") + return "rocm" + except (subprocess.SubprocessError, FileNotFoundError): + if is_xpu_available(): + print("Intel GPU detected") + return "xpu" + elif is_ascend_available(): + print("Ascend NPU detected") + return "npu" + else: + print("No GPU detected") + return "cpu" + + +setup( + name="liger_kernel", + package_dir={"": "src"}, + packages=["liger_kernel"], + install_requires=get_default_dependencies(), + extras_require=get_optional_dependencies(), + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", + "License :: OSI Approved :: BSD-2-Clause Software License", + "Operating System :: OS Independent", + ], +) diff --git a/src/liger_kernel/__init__.py b/src/liger_kernel/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/liger_kernel/chunked_loss/README.md b/src/liger_kernel/chunked_loss/README.md new file mode 100755 index 0000000000000000000000000000000000000000..1dd7037f2dec6cf6189c224111f6825c0d265e2d --- /dev/null +++ b/src/liger_kernel/chunked_loss/README.md @@ -0,0 +1,25 @@ +# Liger FlexChunkLoss: Alignment and Distillation loss + +Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO, KTO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases. + +### User interface + +FlexChunkLoss offers two flexible usage options: + +1. **Via `Liger[Custom Loss]Trainer`** + For example, by simply replacing the HuggingFace `ORPOTrainer` with `LigerORPOTrainer` in your code, you can leverage our optimized ORPO implementation and immediately benefit from improved performance. + +2. **Using `nn.Module` Implementations of Custom Loss Functions** + Explore the [LigerORPOTrainer implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/orpo_trainer.py) to see how the modular design integrates custom loss functions seamlessly. 
+
+### What's under the hood?
+
+We employ chunking and fused kernel optimizations to enhance performance. By fusing the final linear layer with loss computation and calculating backward gradients during the forward pass, we significantly reduce the need for storing intermediate activations. All operations are implemented in PyTorch, leveraging `torch.compile` to streamline kernel execution without relying on extensive low-level optimizations. Additionally, we minimize `torch.compile` recompilations to reduce overhead and ensure consistent performance gains. A simplified sketch of the chunked, gradient-in-forward idea follows this file.
+
+### Extending to custom loss functions
+
+We provide two base classes: `LigerFusedLinearPreferenceBase` for alignment use cases and `LigerFusedLinearDistillationBase` for distillation use cases. These base classes manage chunking, kernel fusions, and Torch compilation.
+
+To implement a custom loss function, you need to create a subclass that defines the custom preference or distillation loss function, capable of processing a given input chunk. The base class will take care of the optimizations, handling most of the heavy lifting for you.
+
+For a working example, refer to the [ORPO loss implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/chunked_loss/orpo_loss.py).
\ No newline at end of file
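The following is a deliberately simplified sketch of the chunked, gradient-in-forward idea described above. It is not the library's implementation: there is no `ignore_index` or temperature handling, eager autograd stands in for the hand-fused steps, and the real base classes wrap this pattern in a `torch.autograd.Function` whose backward simply returns the pre-computed gradients, with the per-chunk step run under `torch.compile`:

```python
import torch
import torch.nn.functional as F


def chunked_fused_linear_ce(hidden, weight, labels, chunk_size=1024):
    """Mean cross entropy over (hidden @ weight.T), computed chunk by chunk,
    accumulating input/weight gradients during the forward pass so the full
    (N, vocab_size) logits tensor is never materialized."""
    n = hidden.shape[0]
    total_loss = torch.zeros((), dtype=hidden.dtype)
    grad_hidden = torch.zeros_like(hidden)
    grad_weight = torch.zeros_like(weight)
    for s in range(0, n, chunk_size):
        h = hidden[s : s + chunk_size].detach().requires_grad_()
        w = weight.detach().requires_grad_()
        # Only a (chunk_size, vocab_size) logits slice is alive at any one time.
        chunk_loss = F.cross_entropy(h @ w.T, labels[s : s + chunk_size], reduction="sum")
        chunk_loss.backward()  # gradient-in-forward: grads are ready before the next chunk
        total_loss += chunk_loss.detach()
        grad_hidden[s : s + chunk_size] = h.grad
        grad_weight += w.grad
    return total_loss / n, grad_hidden / n, grad_weight / n


hidden = torch.randn(4096, 512)            # flattened (batch * seq_len, hidden)
weight = torch.randn(32000, 512)           # lm_head weight (vocab, hidden)
labels = torch.randint(0, 32000, (4096,))
loss, grad_hidden, grad_weight = chunked_fused_linear_ce(hidden, weight, labels)
```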
diff --git a/src/liger_kernel/chunked_loss/__init__.py b/src/liger_kernel/chunked_loss/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..d3624adbbfb6455f245fa8c98e191d6f0db03fa5
--- /dev/null
+++ b/src/liger_kernel/chunked_loss/__init__.py
@@ -0,0 +1,8 @@
+from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss  # noqa: F401
+from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss  # noqa: F401
+from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss  # noqa: F401
+from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss  # noqa: F401
diff --git a/src/liger_kernel/chunked_loss/cosine_similarity_loss.py b/src/liger_kernel/chunked_loss/cosine_similarity_loss.py
new file mode 100755
index 0000000000000000000000000000000000000000..553ca34237cf98460bb27554ee65e88ef2fe9f92
--- /dev/null
+++ b/src/liger_kernel/chunked_loss/cosine_similarity_loss.py
@@ -0,0 +1,142 @@
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(
+        student_logits,
+        teacher_logits,
+        target=None,
+        ignore_index=None,
+        beta=1.0,
+    ):
+        """
+        Compute the cosine similarity distillation loss.
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+            target, ignore_index: Unused here; kept for interface compatibility with the distillation base class.
+            beta (float): Coefficient of the generalized cosine similarity loss, in the interval [0, 1]. Default: 1.0.
+        Returns:
+            torch.Tensor: cosine similarity loss
+        """
+        # NOTE: F.cosine_similarity normalizes its inputs as well, so the explicit
+        # normalization below is not strictly required.
+        student_norm = F.normalize(student_logits, p=2, dim=-1)
+        teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
+
+        cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1)
+        loss = beta * (1 - cosine_sim)
+        return loss.sum()
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            chunk_size=chunk_size,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
+
+        return (
+            *grads,
+            None,  # teacher_bias
+            None,  # weight_hard_loss
+            None,  # weight_soft_loss
+            None,  # beta
+            None,  # ignore_index
+            None,  # temperature
+            None,  # compiled
+            None,  # chunk_size
+            None,  # return_soft_hard_loss
+        )
+
+
+class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ):
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+ self.weight_hard_loss = weight_hard_loss + self.weight_soft_loss = weight_soft_loss + self.ignore_index = ignore_index + self.temperature = temperature + self.compiled = compiled + self.beta = beta + self.chunk_size = chunk_size + self.return_soft_hard_loss = return_soft_hard_loss + + def forward( + self, + student_input: torch.Tensor, + student_weight: torch.Tensor, + teacher_input: torch.Tensor, + teacher_weight: torch.Tensor, + true_labels: torch.LongTensor, + student_bias: torch.Tensor = None, + teacher_bias: torch.Tensor = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + return LigerFusedLinearCosineSimilarityFunction.apply( + student_input, + student_weight, + teacher_input, + teacher_weight, + true_labels, + student_bias, + teacher_bias, + self.weight_hard_loss, + self.weight_soft_loss, + self.beta, + self.ignore_index, + self.temperature, + self.compiled, + self.chunk_size, + self.return_soft_hard_loss, + ) diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..1f0a154974ee689a34db3df367c4e54c20f67916 --- /dev/null +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -0,0 +1,157 @@ +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase + + +class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase): + @staticmethod + def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0): + """ + Paper: https://arxiv.org/pdf/2401.08417 + + Formula: + L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))] + + Where: + - π_θ(y|x): Policy (model) probability + - y_w: Chosen sequence + - y_l: Rejected sequence + - σ: Sigmoid function + - β: Temperature parameter + - E: Expected value over the dataset D + - D: Dataset of preferences + + Args: + chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). + rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). + full_target (torch.Tensor): Non chunked full target tensor + beta (float): Weight for the CPO loss + label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0. + """ + logits = beta * (chosen_logps - rejected_logps) + loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / ( + full_target.shape[0] // 2 + ) + + chosen_rewards = beta * chosen_logps + rejected_rewards = beta * rejected_logps + + return loss, chosen_rewards, rejected_rewards + + @classmethod + def forward( + cls, + ctx, + _input, + weight, + target, + bias=None, + ignore_index=-100, + beta=0.1, + alpha=1.0, + label_smoothing=0.0, + compute_nll_loss=True, + compiled=True, + average_log_prob=False, + chunk_size=1, + ): + """ + Fused linear layer with CPO loss. + Args: + _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size) + weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size) + target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,) + bias (torch.Tensor, optional): Bias tensor. 
Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the CPO loss
+            alpha (float): Weight for the NLL loss term
+            label_smoothing (float): Label smoothing factor
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            compiled (bool): Whether to use torch compile
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing.
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ignore_index=ignore_index,
+            alpha=alpha,
+            beta=beta,
+            label_smoothing=label_smoothing,
+            compute_nll_loss=compute_nll_loss,
+            average_log_prob=average_log_prob,
+            compiled=compiled,
+            chunk_size=chunk_size,
+        )
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        return *grads, None, None, None, None, None, None, None, None
+
+
+class LigerFusedLinearCPOLoss(torch.nn.Module):
+    """
+    Fused linear layer with CPO loss.
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        alpha: float = 1.0,
+        label_smoothing: float = 0.0,
+        compute_nll_loss: bool = True,
+        compiled: bool = True,
+        average_log_prob: bool = False,
+        chunk_size: int = 1,
+    ):
+        """
+        Args:
+            ignore_index (int): Index to ignore in the loss.
+            beta (float): Weight for the CPO loss.
+            alpha (float): Weight for the NLL loss term.
+            label_smoothing (float): Label smoothing factor.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
+            chunk_size (int): Size of chunks for processing.
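+
+        Example (illustrative; the tensor arguments are placeholders shaped as
+        documented in LigerFusedLinearCPOFunction.forward; the weight is passed
+        first, and the module returns the loss together with a tuple of
+        aggregated metrics):
+            >>> loss_fn = LigerFusedLinearCPOLoss(beta=0.1, alpha=1.0)
+            >>> loss, metrics = loss_fn(lin_weight, _input, target)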
+ """ + super().__init__() + self.ignore_index = ignore_index + self.beta = beta + self.alpha = alpha + self.label_smoothing = label_smoothing + self.compute_nll_loss = compute_nll_loss + self.compiled = compiled + self.average_log_prob = average_log_prob + self.chunk_size = chunk_size + + def forward( + self, + lin_weight, + _input, + target, + bias=None, + ): + return LigerFusedLinearCPOFunction.apply( + _input, + lin_weight, + target, + bias, + self.ignore_index, + self.beta, + self.alpha, + self.label_smoothing, + self.compute_nll_loss, + self.compiled, + self.average_log_prob, + self.chunk_size, + ) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..f7a14e539e45130b1301be481f71d5100dceecc7 --- /dev/null +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -0,0 +1,229 @@ +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase + + +class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase): + @staticmethod + def preference_loss_fn( + chosen_logps, + rejected_logps, + full_target, + ref_chosen_logps=None, + ref_rejected_logps=None, + beta=0.1, + loss_type="sigmoid", + ): + """ + Paper: https://arxiv.org/pdf/2305.18290 + + Formula: + L_DPO = -E[ log_sigmoid( β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))) ) ] + + Where: + - π(y|x): Policy (model) probability + - π_ref(y|x): Reference model probability + - y_w: Chosen sequence + - y_l: Rejected sequence + - β: Weight for the direct preference loss + - E: Expected value over the dataset + + Args: + chosen_logps: Log probabilities of chosen tokens (batch_size,) + rejected_logps: Log probabilities of rejected tokens (batch_size,) + full_target: Non chunked full target tensor + ref_chosen_logps: Reference log probs of chosen tokens (batch_size,) + ref_rejected_logps: Reference log probs of rejected tokens (batch_size,) + beta: Weight for the direct preference loss + """ + + if ref_chosen_logps is None: + ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device) + if ref_rejected_logps is None: + ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device) + + chosen_logratios = chosen_logps - ref_chosen_logps + rejected_logratios = rejected_logps - ref_rejected_logps + + chosen_rewards = beta * chosen_logratios + rejected_rewards = beta * rejected_logratios + + if loss_type == "sigmoid": + logits_diff = beta * (chosen_logratios - rejected_logratios) + loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2) + + elif loss_type == "apo_zero": + # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) + # Use this loss when you believe the chosen outputs are better than your model's default output + losses_chosen = 1 - F.sigmoid(beta * chosen_logratios) # Increase chosen likelihood + losses_rejected = F.sigmoid(beta * rejected_logratios) + losses = losses_chosen + losses_rejected + loss = losses.sum() / (full_target.shape[0] // 2) + + elif loss_type == "apo_down": + # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266) + # Use this loss when you believe the chosen outputs are worse than your model's default output. 
+            # Decrease chosen likelihood and decrease rejected likelihood more
+            losses_chosen = F.sigmoid(beta * chosen_logratios)
+            losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "sppo_hard":
+            # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+            # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+            # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+            # set to 1 for the winner and 0 for the loser.
+            a = chosen_logps - ref_chosen_logps
+            b = rejected_logps - ref_rejected_logps
+            losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "nca_pair":
+            losses = (
+                -F.logsigmoid(chosen_rewards)
+                - 0.5 * F.logsigmoid(-chosen_rewards)
+                - 0.5 * F.logsigmoid(-rejected_rewards)
+            )
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        else:
+            raise ValueError(
+                f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
+            )
+
+        return loss, chosen_rewards, rejected_rewards
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        ignore_index=-100,
+        beta=0.1,
+        compute_nll_loss=False,
+        compiled=True,
+        use_ref_model=True,
+        average_log_prob=False,
+        chunk_size=1,
+        loss_type="sigmoid",
+    ):
+        """
+        Fused linear layer with DPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the direct preference loss
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing.
+            loss_type (str): Loss variant to use; one of "sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair".
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ignore_index=ignore_index,
+            beta=beta,
+            compute_nll_loss=compute_nll_loss,
+            compiled=compiled,
+            use_ref_model=use_ref_model,
+            ref_input=ref_input,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
+            chunk_size=chunk_size,
+            loss_type=loss_type,
+        )
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        return *grads, None, None, None, None, None, None, None, None, None, None, None
+
+
+class LigerFusedLinearDPOLoss(torch.nn.Module):
+    """
+    Fused linear layer with DPO loss.
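+
+    Example (illustrative; the tensor arguments are placeholders, the weight is
+    passed first, and the module returns the loss together with a tuple of
+    aggregated metrics):
+        >>> dpo_loss = LigerFusedLinearDPOLoss(beta=0.1, use_ref_model=True)
+        >>> loss, metrics = dpo_loss(lin_weight, _input, target, ref_input=ref_input, ref_weight=ref_weight)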
+ """ + + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + compute_nll_loss: bool = False, + compiled: bool = True, + use_ref_model: bool = True, + average_log_prob: bool = False, + chunk_size: int = 1, + loss_type: str = "sigmoid", + ): + """ + Args: + ignore_index (int): Index to ignore in the loss. + beta (float): Weight for the odds ratio loss. + compute_nll_loss (bool): Whether to compute the NLL loss. + compiled (bool): Whether to use the torch compiled kernel. + use_ref_model (bool): Whether to use a reference model for the DPO loss. + average_log_prob (bool): Whether to average the log probability per non-masked token. + chunk_size (int): Size of chunks for processing. + """ + super().__init__() + self.ignore_index = ignore_index + self.beta = beta + self.compute_nll_loss = compute_nll_loss + self.compiled = compiled + self.use_ref_model = use_ref_model + self.average_log_prob = average_log_prob + self.chunk_size = chunk_size + self.loss_type = loss_type + supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"} + if self.loss_type not in supported_loss_types: + raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}") + + def forward( + self, + lin_weight, + _input, + target, + bias=None, + ref_input=None, + ref_weight=None, + ref_bias=None, + ): + return LigerFusedLinearDPOFunction.apply( + _input, + lin_weight, + target, + bias, + ref_input, + ref_weight, + ref_bias, + self.ignore_index, + self.beta, + self.compute_nll_loss, + self.compiled, + self.use_ref_model, + self.average_log_prob, + self.chunk_size, + self.loss_type, + ) diff --git a/src/liger_kernel/chunked_loss/functional.py b/src/liger_kernel/chunked_loss/functional.py new file mode 100755 index 0000000000000000000000000000000000000000..722e60d4fb1c440af5e1c61909c1faadd28a1e9e --- /dev/null +++ b/src/liger_kernel/chunked_loss/functional.py @@ -0,0 +1,17 @@ +from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction +from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction +from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction +from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction +from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction +from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction +from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction +from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction + +liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply +liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply +liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply +liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply +liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply +liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply +liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply +liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py new file mode 100755 index 0000000000000000000000000000000000000000..c58f9320a7301536689b0538b2f8de432782db77 --- /dev/null +++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py @@ -0,0 +1,299 @@ +from abc import abstractmethod +from functools import partial +from typing import Tuple +from typing 
+
+import torch
+
+from torch.nn import functional as F
+
+
+class LigerFusedLinearDistillationBase(torch.autograd.Function):
+    @abstractmethod
+    def distillation_loss_fn(
+        student_logits,
+        teacher_logits,
+        target=None,
+        ignore_index=None,
+    ):
+        """
+        Compute distillation loss.
+        Args:
+            student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+            teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+            target (torch.Tensor, optional): Chunk of ground-truth labels, forwarded by _compute_loss. Shape: (chunk_size,).
+            ignore_index (int, optional): Index to ignore, forwarded by _compute_loss.
+        Returns:
+            torch.Tensor: Sum of distillation losses for the chunk. The class will handle
+            converting this to a mean loss by dividing by the number of non-ignored target tokens in _compute_loss.
+        """
+        raise NotImplementedError("Distillation loss function must be implemented.")
+
+    @staticmethod
+    def chunk_forward(
+        student_input_chunk,
+        student_weight,
+        teacher_input_chunk,
+        teacher_weight,
+        target_chunk,
+        student_bias=None,
+        teacher_bias=None,
+        ignore_index=-100,
+        compute_ce_loss=True,
+    ):
+        # Student
+        student_logits_chunk = student_input_chunk @ student_weight.t()
+        if student_bias is not None:
+            student_logits_chunk += student_bias
+        student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1)
+
+        # Teacher
+        with torch.no_grad():
+            teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
+            if teacher_bias is not None:
+                teacher_logits_chunk += teacher_bias
+
+        # The hard/task loss
+        ce_loss = 0.0
+        if compute_ce_loss:
+            ce_loss = F.nll_loss(
+                student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]),
+                target_chunk.view(-1),
+                reduction="sum",
+                ignore_index=ignore_index,
+            )
+
+        return student_logits_chunk, teacher_logits_chunk, ce_loss
+
+    @staticmethod
+    def _compute_loss(
+        student_input_chunk,
+        student_weight,
+        teacher_input_chunk,
+        teacher_weight,
+        target_chunk,
+        student_bias=None,
+        teacher_bias=None,
+        distillation_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        weight_hard_loss=0.5,
+        weight_soft_loss=0.5,
+        compute_ce_loss=True,
+        temperature=1,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, using a knowledge distillation loss function.
+        Args:
+            distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
+            student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
+            teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
+            teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
+            student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size * sequence_length,).
+            ignore_index (int): Index to ignore for loss computation.
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            compute_ce_loss (bool): Whether to compute CE loss.
+            temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
+            loss_kwargs (dict): Additional arguments for the loss function.
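+
+        The two terms are combined as follows (illustrative; both are already
+        normalized by the number of non-ignored target tokens):
+            loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss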
+ """ + ( + student_logits_chunk, + teacher_logits_chunk, + hard_loss, + ) = LigerFusedLinearDistillationBase.chunk_forward( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias=student_bias, + teacher_bias=teacher_bias, + ignore_index=ignore_index, + compute_ce_loss=compute_ce_loss, + ) + + student_logits_chunk /= temperature + teacher_logits_chunk /= temperature + + # If the teacher and student token size is different, pad student logits to match the teacher's. + # This only applies to cases where they share exactly the same vocab and tokenizer just + # that teacher logit is padded for some training efficiency such as + # https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2 + teacher_vocab_size = teacher_weight.shape[0] + student_vocab_size = student_weight.shape[0] + if teacher_vocab_size > student_vocab_size: + pad_size = teacher_vocab_size - student_vocab_size + pad_tensor = torch.zeros( + (*student_logits_chunk.shape[:-1], pad_size), + dtype=student_logits_chunk.dtype, + device=student_logits_chunk.device, + ) + student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1) + + num_valid_tokens = (full_target != ignore_index).sum() + num_valid_tokens = num_valid_tokens.clamp_min(1) # to avoid division by zero + + hard_loss /= num_valid_tokens + + soft_loss = distillation_loss_fn( + student_logits_chunk, teacher_logits_chunk, target=target_chunk, ignore_index=ignore_index, **loss_kwargs + ) + soft_loss /= num_valid_tokens + + loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss + return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk) + + @staticmethod + def forward( + cls, + ctx, + student_input, + student_weight, + teacher_input, + teacher_weight, + target, + student_bias=None, + teacher_bias=None, + chunk_size=1024, + ignore_index=-100, + weight_hard_loss=0.5, + weight_soft_loss=0.5, + beta=0.5, + compute_ce_loss=True, + temperature=1.0, + compiled=True, + return_soft_hard_loss=False, + **loss_kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + """ + Base class for fused linear layer with distillation loss. + Only need to compute gradients for student model. + + Args: + student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, student_hidden_size). + student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, student_hidden_size). + teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, teacher_hidden_size). + teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, teacher_hidden_size). + target (torch.Tensor): Target truth label tensor. Shape: (batch_size * seq_len). + student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,). + teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,). + loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + chunk_size (int): Size of a chunk. + ignore_index (int): Index to ignore for loss computation. + weight_hard_loss (float): Weight for hard/task loss. + weight_soft_loss (float): Weight for soft/distillation loss. + beta (float): Interpolation coefficient between 0 and 1 (default: 0.5). + compute_ce_loss (bool): Whether to compute CE loss. + temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. 
no scale) + compiled (bool): Whether to use torch compile for chunk accumulation. + return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False. + loss_kwargs (dict): Other possible arguments that a loss function might need + """ + CHUNK_SIZE = chunk_size + grad_weight = torch.zeros_like(student_weight) + grad_inputs = [] + grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None + loss_acc = torch.zeros((), device=student_input.device) + soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None + hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None + + loss_func_to_call = partial( + LigerFusedLinearDistillationBase._compute_loss, + distillation_loss_fn=cls.distillation_loss_fn, + full_target=target, + ignore_index=ignore_index, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + compute_ce_loss=compute_ce_loss, + temperature=temperature, + beta=beta, + **loss_kwargs, + ) + + def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk): + if student_bias is not None: + ( + (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), + ( + chunk_loss, + ( + chunk_soft_loss, + chunk_hard_loss, + chunk_student_logits, + chunk_teacher_logits, + ), + ), + ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1, 5), has_aux=True)( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias, + teacher_bias, + ) + grad_bias.add_(chunk_grad_bias) + else: + ( + (chunk_grad_input, chunk_grad_weight), + ( + chunk_loss, + ( + chunk_soft_loss, + chunk_hard_loss, + chunk_student_logits, + chunk_teacher_logits, + ), + ), + ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1), has_aux=True)( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias, + teacher_bias, + ) + grad_weight.add_(chunk_grad_weight) + loss_acc.add_(chunk_loss) + if return_soft_hard_loss: + soft_loss_acc.add_(chunk_soft_loss) + hard_loss_acc.add_(chunk_hard_loss) + return chunk_grad_input + + if compiled: + accumulate_chunk = torch.compile(accumulate_chunk) + + num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE) + _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0) + _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0) + _target_chunks = torch.chunk(target, chunks=num_chunks, dim=0) + + for student_input_chunk, teacher_input_chunk, target_chunk in zip( + _student_input_chunks, _teacher_input_chunks, _target_chunks + ): + grad_input = accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk) + grad_inputs.append(grad_input) + + ctx.save_for_backward( + torch.cat(grad_inputs, dim=0), + grad_weight, + grad_bias, + ) + if return_soft_hard_loss: + return loss_acc, soft_loss_acc, hard_loss_acc + return loss_acc + + @staticmethod + def backward(ctx, grad_output, *args): + grad_input, grad_weight, grad_bias = ctx.saved_tensors + if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): + grad_input = grad_input * grad_output + grad_weight = grad_weight * grad_output + grad_bias = grad_bias * grad_output if grad_bias is not None else None + + return grad_input, grad_weight, None, None, None, grad_bias diff --git a/src/liger_kernel/chunked_loss/fused_linear_ppo.py b/src/liger_kernel/chunked_loss/fused_linear_ppo.py new file mode 100755 index 
0000000000000000000000000000000000000000..a382cda1b4470db4f217967cd2c4b4293d8e224c
--- /dev/null
+++ b/src/liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -0,0 +1,421 @@
+import warnings
+
+from abc import abstractmethod
+from functools import partial
+
+import torch
+import torch._dynamo.config
+import torch.nn.functional as F
+
+
+class LigerFusedLinearPPOBase(torch.autograd.Function):
+    @abstractmethod
+    def ppo_loss_fn(*args, **kwargs):
+        """
+        To be extended by subclasses.
+        """
+        raise NotImplementedError("PPO loss function must be implemented.")
+
+    @staticmethod
+    def forward(
+        cls,
+        ctx,
+        _input,
+        weight,
+        selected_token_ids,
+        attention_mask,
+        advantages,
+        bias=None,
+        ref_per_token_logps=None,
+        old_per_token_logps=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        beta=0.04,
+        loss_type="dapo",
+        max_completion_length=None,
+        importance_sampling_level="token",
+        temperature=1.0,
+        compiled=True,
+        use_ref_model=False,
+        chunk_size=1,
+        sapo_temperature_pos=1.0,
+        sapo_temperature_neg=1.05,
+        vllm_is_ratio=None,
+        delta=None,
+        use_bias_correction_kl=False,
+    ):
+        # TODO: check torch compile matmul
+        """Chunked forward pass for PPO loss computation.
+
+        Args:
+            cls: The class
+            ctx: Context for backward
+            _input: Input tensor
+            weight: Weight tensor
+            selected_token_ids: Selected token ids tensor
+            attention_mask: Attention mask tensor
+            advantages: Advantages tensor
+            bias: Bias tensor
+            ref_per_token_logps: Reference model log probs per token tensor
+            old_per_token_logps: Old per token log probabilities tensor
+            ref_input: Reference model input tensor
+            ref_weight: Reference model weight tensor
+            ref_bias: Reference model bias tensor
+            epsilon_low: Lower bound for clipping the importance sampling ratio
+            epsilon_high: Upper bound for clipping the importance sampling ratio
+            beta: Weight for the KL penalty
+            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo")
+            max_completion_length: Maximum completion length required for "dr_grpo"
+            importance_sampling_level: Level of importance sampling ("token" or "sequence")
+            temperature: Temperature for the logits
+            compiled: Whether to use torch compile
+            use_ref_model: Whether to use a reference model
+            chunk_size: Size of chunks for processing in other loss modules
+            sapo_temperature_pos: Temperature for positive advantages in SAPO
+            sapo_temperature_neg: Temperature for negative advantages in SAPO
+            vllm_is_ratio: vLLM importance sampling ratio tensor (batch_size, seq_len) or (batch_size, 1) or None.
+                Used to correct for distribution mismatch when using vLLM for generation.
+        """
+        if use_ref_model:
+            assert ref_per_token_logps is not None or ref_input is not None, (
+                "If use_ref_model is True, ref_per_token_logps or ref_input must be provided"
+            )
+            if ref_per_token_logps is not None and ref_input is not None:
+                warnings.warn("Both ref_per_token_logps and ref_input are provided. 
Using ref_per_token_logps.") + if loss_type == "dr_grpo": + assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'" + if vllm_is_ratio is not None: + B, T = attention_mask.shape + assert vllm_is_ratio.dim() in (1, 2), ( + f"vllm_is_ratio must be 1D (B,) or 2D (B, T) / (B, 1), got {vllm_is_ratio.dim()}D" + ) + if vllm_is_ratio.dim() == 2: + assert vllm_is_ratio.shape[0] == B and vllm_is_ratio.shape[1] in (1, T), ( + f"vllm_is_ratio shape must be ({B}, 1) or ({B}, {T}), got {tuple(vllm_is_ratio.shape)}" + ) + else: + assert vllm_is_ratio.shape[0] == B, ( + f"vllm_is_ratio shape must be ({B},), got {tuple(vllm_is_ratio.shape)}" + ) + vllm_is_ratio = vllm_is_ratio.unsqueeze(-1) # (B,) -> (B, 1) for broadcasting + # Initialize accumulators + loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32) + grad_weight = torch.zeros_like(weight) # [V, H] + grad_inputs = [] + grad_bias = torch.zeros_like(bias) if bias is not None else None # [V] + aggregated_metrics = [] + + # Create a partial function with fixed arguments + compute_loss = partial( + LigerFusedLinearPPOBase._compute_chunk_loss, + ref_weight=ref_weight, + ref_bias=ref_bias, + full_attention_mask=attention_mask, + epsilon_low=epsilon_low, + epsilon_high=epsilon_high, + beta=beta, + loss_type=loss_type, + max_completion_length=max_completion_length, + importance_sampling_level=importance_sampling_level, + temperature=temperature, + use_ref_model=use_ref_model, + ppo_loss_fn=cls.ppo_loss_fn, + sapo_temperature_pos=sapo_temperature_pos, + sapo_temperature_neg=sapo_temperature_neg, + delta=delta, + use_bias_correction_kl=use_bias_correction_kl, + ) + + def fused_fwd_bwd( + input_chunk, + selected_token_ids_chunk, + attention_mask_chunk, + advantages_chunk, + ref_per_token_logps_chunk, + old_per_token_logps_chunk, + ref_input_chunk, + vllm_is_ratio_chunk, + ): + """Fused forward and backward for a chunk.""" + argnums = (0, 1, 5) if bias is not None else (0, 1) + return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)( + input_chunk, # arg 0 + weight, # arg 1 + selected_token_ids_chunk, # arg 2 + attention_mask_chunk, # arg 3 + advantages_chunk, # arg 4 + bias, # arg 5 + ref_per_token_logps_chunk=ref_per_token_logps_chunk, # arg 6 + old_per_token_logps_chunk=old_per_token_logps_chunk, # arg 7 + ref_input_chunk=ref_input_chunk, # arg 8 + vllm_is_ratio_chunk=vllm_is_ratio_chunk, # arg 9 + ) + + def accumulate_chunk( + input_chunk, + selected_token_ids_chunk, + attention_mask_chunk, + advantages_chunk, + ref_per_token_logps_chunk=None, + old_per_token_logps_chunk=None, + ref_input_chunk=None, + vllm_is_ratio_chunk=None, + ): + (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd( + input_chunk, + selected_token_ids_chunk, + attention_mask_chunk, + advantages_chunk, + ref_per_token_logps_chunk, + old_per_token_logps_chunk, + ref_input_chunk, + vllm_is_ratio_chunk, + ) + if bias is not None: + grad_bias.add_(chunk_grad_bias[0]) + + # Accumulate gradients and loss + grad_weight.add_(chunk_grad_weight) + grad_inputs.append(chunk_grad_input) + loss_acc.add_(chunk_loss) + # Initialize storage for metrics on first chunk + if len(aggregated_metrics) == 0: + for metric in chunk_metrics: + if metric.ndim == 0: + aggregated_metrics.append(torch.zeros((), device=metric.device)) + else: + aggregated_metrics.append([]) + + # Accumulate metrics + for i, metric in enumerate(chunk_metrics): + if metric.ndim == 0: + 
aggregated_metrics[i].add_(metric) + else: + aggregated_metrics[i].append(metric) + + if compiled: + # TODO: Figure out what is better to compile here + # accumulate_chunk = torch.compile(accumulate_chunk) + fused_fwd_bwd = torch.compile(fused_fwd_bwd) + + # Process input in chunks based on chunk_size + chunks = max(1, _input.shape[0] // chunk_size) + _input_chunks = torch.chunk(_input, chunks=chunks, dim=0) + _selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0) + _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0) + _advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0) + _ref_per_token_logps_chunks = ( + torch.chunk(ref_per_token_logps, chunks=chunks, dim=0) + if use_ref_model and ref_per_token_logps is not None + else [None] * chunks + ) + _old_per_token_logps_chunks = ( + torch.chunk(old_per_token_logps, chunks=chunks, dim=0) + if old_per_token_logps is not None + else [None] * chunks + ) + # if ref_log_probs is not none, then we don't need ref_input to calculate the log probs + _ref_input_chunks = ( + torch.chunk(ref_input, chunks=chunks, dim=0) + if use_ref_model and ref_per_token_logps is None + else [None] * chunks + ) + _vllm_is_ratio_chunks = ( + torch.chunk(vllm_is_ratio, chunks=chunks, dim=0) if vllm_is_ratio is not None else [None] * chunks + ) + + for ( + input_chunk, + selected_token_ids_chunk, + attention_mask_chunk, + advantages_chunk, + ref_per_token_logps_chunk, + old_per_token_logps_chunk, + ref_input_chunk, + vllm_is_ratio_chunk, + ) in zip( + _input_chunks, + _selected_token_ids_chunks, + _attention_mask_chunks, + _advantages_chunks, + _ref_per_token_logps_chunks, + _old_per_token_logps_chunks, + _ref_input_chunks, + _vllm_is_ratio_chunks, + ): + # Mark dynamic dimensions + torch._dynamo.mark_dynamic(input_chunk, 1) + torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1) + torch._dynamo.mark_dynamic(attention_mask_chunk, 1) + if ref_per_token_logps_chunk is not None: + torch._dynamo.mark_dynamic(ref_per_token_logps_chunk, 1) + if ref_input_chunk is not None: + torch._dynamo.mark_dynamic(ref_input_chunk, 1) + if old_per_token_logps_chunk is not None: + torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1) + if vllm_is_ratio_chunk is not None: + torch._dynamo.mark_dynamic(vllm_is_ratio_chunk, 1) + + accumulate_chunk( + input_chunk, + selected_token_ids_chunk, + attention_mask_chunk, + advantages_chunk, + ref_per_token_logps_chunk, + old_per_token_logps_chunk, + ref_input_chunk, + vllm_is_ratio_chunk, + ) + + # Combine gradients + grad_input = torch.cat(grad_inputs, dim=0) + + # Save for backward + ctx.save_for_backward(grad_input, grad_weight, grad_bias) + + # Finalize metrics + final_metrics = [] + for metric in aggregated_metrics: + if isinstance(metric, list): + final_metrics.append(torch.cat(metric, dim=0)) + else: + final_metrics.append(metric) + + return loss_acc, tuple(final_metrics) + + @staticmethod + def _compute_dapo_normalizer(attention_mask): + """Global active tokens averaged per process.""" + normalizer = attention_mask.to(torch.float32).sum() + world_size = 1 + if torch.distributed.is_available() and torch.distributed.is_initialized(): + import torch.distributed as dist + + normalizer = normalizer.clone() + dist.all_reduce(normalizer, op=dist.ReduceOp.SUM) + world_size = dist.get_world_size() + + normalizer = normalizer / world_size + return torch.clamp(normalizer, min=1.0) + + @staticmethod + def _compute_chunk_loss( + input_chunk, + weight, + selected_token_ids_chunk, + 
attention_mask_chunk, + advantages_chunk, + bias=None, + ref_per_token_logps_chunk=None, + old_per_token_logps_chunk=None, + ref_input_chunk=None, + vllm_is_ratio_chunk=None, + ref_weight=None, + ref_bias=None, + full_attention_mask=None, + epsilon_low=0.2, + epsilon_high=0.2, + beta=0.04, + loss_type="dapo", + max_completion_length=None, + importance_sampling_level="token", + temperature=1.0, + use_ref_model=False, + ppo_loss_fn=None, + sapo_temperature_pos=1.0, + sapo_temperature_neg=1.05, + delta=None, + use_bias_correction_kl=False, + ): + """Compute loss for a single chunk.""" + # Get policy log probabilities using chunk_forward + log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature) + + # Get reference log probabilities if needed + ref_log_probs = None + if use_ref_model and ref_per_token_logps_chunk is None: + with torch.no_grad(): + ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward( + ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature + ) + + # Compute chunk loss and metrics using the provided loss function + chunk_loss, chunk_metrics = ppo_loss_fn( + log_probs=log_probs, + selected_token_ids=selected_token_ids_chunk, + attention_mask=attention_mask_chunk, + advantages=advantages_chunk, + full_attention_mask=full_attention_mask, + ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None, + old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None, + ref_log_probs=ref_log_probs, # used when ref_per_token_logps is None + epsilon_low=epsilon_low, + epsilon_high=epsilon_high, + beta=beta, + loss_type=loss_type, + max_completion_length=max_completion_length, + importance_sampling_level=importance_sampling_level, + sapo_temperature_pos=sapo_temperature_pos, + sapo_temperature_neg=sapo_temperature_neg, + vllm_is_ratio=vllm_is_ratio_chunk, + delta=delta, + use_bias_correction_kl=use_bias_correction_kl, + ) + + return chunk_loss, chunk_metrics + + @staticmethod + def chunk_forward(input_chunk, weight, bias=None, temperature=1.0): + """Forward pass computation for a single chunk without explicit reshaping.""" + # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V] + logits = torch.matmul(input_chunk, weight.t()) + if bias is not None: + logits = logits + bias # Broadcasts bias to [B, T, V] + if temperature != 1.0: + logits = logits / temperature + + # Compute log probabilities using softmax over the last dimension + log_probs = F.log_softmax(logits.float(), dim=-1) + + return log_probs, logits + + @staticmethod + def backward(ctx, grad_output, *grad_metrics): + """Backward pass for PPO loss.""" + grad_input, grad_weight, grad_bias = ctx.saved_tensors + + if grad_output != 1.0: + grad_input = grad_input * grad_output + grad_weight = grad_weight * grad_output + if grad_bias is not None: + grad_bias = grad_bias * grad_output + + return ( + grad_input, + grad_weight, + None, # grad_selected_token_ids + None, # grad_attention_mask + None, # grad_advantages + grad_bias, + None, # grad_ref_per_token_logps + None, # grad_old_per_token_logps + None, # grad_ref_input + None, # grad_ref_weight + None, # grad_ref_bias + None, # grad_epsilon_low + None, # grad_epsilon_high + None, # grad_beta + None, # grad_loss_type + None, # grad_max_completion_length + None, # grad_importance_sampling_level + None, # grad_temperature + None, # grad_compiled + None, # grad_use_ref_model + None, # 
grad_chunk_size + None, # grad_sapo_temperature_pos + None, # grad_sapo_temperature_neg + None, # grad_vllm_is_ratio + None, # grad_delta + None, # grad_use_bias_correction_kl + ) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py new file mode 100755 index 0000000000000000000000000000000000000000..72269be663089da7db09b87c0a9fc494adbeb3e3 --- /dev/null +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -0,0 +1,433 @@ +from abc import abstractmethod +from functools import partial + +import torch + +from torch.nn import functional as F + + +class LigerFusedLinearPreferenceBase(torch.autograd.Function): + @abstractmethod + def preference_loss_fn(*args, **kwargs): + """ + To be extended by subclasses. + """ + raise NotImplementedError("Preference loss function must be implemented.") + + @staticmethod + def forward( + cls, + ctx, + _input, + weight, + target, + bias=None, + chunk_size=1, + ignore_index=-100, + alpha=1.0, + beta=0.1, + compute_nll_loss=True, + nll_target=None, + compiled=True, + use_ref_model=False, + ref_input=None, + ref_weight=None, + ref_bias=None, + average_log_prob=True, + **loss_kwargs, + ): + """ + Base class for fused linear layer with preference loss. + Expects _input to be stacked with chosen and rejected inputs on the batch dimension. + + The mental model is: + + forward() + ├── Loop over chunks + └── compute_loss() + ├── chunk_forward() # Compute logits and log probs + └── prefer_loss() # Calculate preference loss + + Args: + _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size). + weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). + target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len). + bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs). + ignore_index (int): Index to ignore for loss computation. + alpha (float): Weight for the NLL loss. + beta (float): Weight for the preference loss. + compute_nll_loss (bool): Whether to compute NLL loss. + nll_target (torch.Tensor, optional): Target tensor for NLL loss. Shape: (batch_size, seq_len). If not provided the target is used. + compiled (bool): Whether to use torch compile for chunk accumulation. + use_ref_model (bool): Whether to use a reference model for the alignment loss. + ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). + ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). + average_log_prob (bool): Whether to average log probabilities or to sum them over the completion. 
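+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size, seq_len, hidden_size).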
+ loss_kwargs (dict): Other possible arguments that a loss function might need + """ + # TODO: Tune CHUNK_SIZE to fully utilize the GPU + CHUNK_SIZE = chunk_size + + # Gradients to be accumulated + grad_weight = torch.zeros_like(weight) + grad_chosen_inputs = [] + grad_rejected_inputs = [] + grad_bias = torch.zeros_like(bias) if bias is not None else None + + # Loss to be accumulated + loss_acc = torch.zeros((), device=_input.device) + + # Metrics to be recorded + policy_chosen_logps = [] + policy_rejected_logps = [] + policy_chosen_logits_mean = torch.zeros((), device=_input.device) + policy_rejected_logits_mean = torch.zeros((), device=_input.device) + policy_nll_loss = torch.zeros((), device=_input.device) + aggregated_aux_outputs = [] # aggregated aux outputs from all chunks + + compute_loss = partial( + LigerFusedLinearPreferenceBase._compute_loss, + preference_loss_fn=cls.preference_loss_fn, + ignore_index=ignore_index, + alpha=alpha, + beta=beta, + compute_nll_loss=compute_nll_loss, + full_target=target, + use_ref_model=use_ref_model, + ref_weight=ref_weight, + ref_bias=ref_bias, + full_nll_target=nll_target, + average_log_prob=average_log_prob, + **loss_kwargs, + ) + + def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk): + """ + Fused forward and backward pass for a chunk of input and target. + """ + if bias is not None: + return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 3), has_aux=True)( + input_chunk, + weight, + target_chunk, + bias, + ref_input_chunk=ref_input_chunk, + chosen_nll_target_chunk=chosen_nll_target_chunk, + ) + else: + return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)( + input_chunk, + weight, + target_chunk, + ref_input_chunk=ref_input_chunk, + chosen_nll_target_chunk=chosen_nll_target_chunk, + ) + + def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None, chosen_nll_target_chunk=None): + if bias is not None: + ( + (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), + ( + chunk_loss, + ( + chunk_chosen_logps, + chunk_rejected_logps, + chunk_chosen_logits_mean, + chunk_rejected_logits_mean, + chunk_nll_loss, + *aux_outputs, + ), + ), + ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk) + grad_bias.add_(chunk_grad_bias) # accumulate bias gradient + else: + ( + (chunk_grad_input, chunk_grad_weight), + ( + chunk_loss, + ( + chunk_chosen_logps, + chunk_rejected_logps, + chunk_chosen_logits_mean, + chunk_rejected_logits_mean, + chunk_nll_loss, + *aux_outputs, + ), + ), + ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk) + + # Accumulate gradients + grad_weight.add_(chunk_grad_weight) + grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]]) + grad_rejected_inputs.append(chunk_grad_input[chosen_target_chunk.shape[0] :]) + + # Accumulate loss + loss_acc.add_(chunk_loss) + + # Accumulate metrics + policy_chosen_logps.append(chunk_chosen_logps) + policy_rejected_logps.append(chunk_rejected_logps) + policy_chosen_logits_mean.add_(chunk_chosen_logits_mean) + policy_rejected_logits_mean.add_(chunk_rejected_logits_mean) + policy_nll_loss.add_(chunk_nll_loss) + + # aux_outputs + # Initialize storage for aux_outputs + if len(aggregated_aux_outputs) == 0: + for aux in aux_outputs: + if aux.ndim == 0: + aggregated_aux_outputs.append(torch.zeros((), device=aux.device)) + else: + aggregated_aux_outputs.append([]) + + # Process each aux_output + for i, aux in enumerate(aux_outputs): + if aux.ndim == 0: 
+ aggregated_aux_outputs[i].add_(aux) + else: + aggregated_aux_outputs[i].append(aux) + + if compiled: + fused_fwd_bwd = torch.compile(fused_fwd_bwd) + + len_chosen = target.shape[0] // 2 + chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE)) + _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0) + _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0) + _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0) + _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0) + + if nll_target is not None: + _chosen_nll_target_chunks = torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0) + + if use_ref_model: + _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0) + _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0) + + for ( + chosen_input_chunk, + rejected_input_chunk, + chosen_target_chunk, + rejected_target_chunk, + ref_chosen_input_chunk, + ref_rejected_input_chunk, + chosen_nll_target_chunk, + ) in zip( + _chosen_input_chunks, + _rejected_input_chunks, + _chosen_target_chunks, + _rejected_target_chunks, + (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)), + (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)), + (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)), + ): + input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0) + ref_input_chunk = ( + torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0) if use_ref_model else None + ) + target_chunk = torch.cat([chosen_target_chunk, rejected_target_chunk], dim=0) + + # mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation + torch._dynamo.mark_dynamic(input_chunk, 1) + torch._dynamo.mark_dynamic(target_chunk, 1) + torch._dynamo.mark_dynamic(target, 1) + torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None + torch._dynamo.mark_dynamic(chosen_nll_target_chunk, 1) if nll_target is not None else None + + # accumulate loss, gradients, and metrics + accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk) + + # combine grad_chosen_inputs and grad_rejected_inputs + grad_inputs = grad_chosen_inputs + grad_rejected_inputs + policy_chosen_logps = torch.cat(policy_chosen_logps, dim=0) + policy_rejected_logps = torch.cat(policy_rejected_logps, dim=0) + + # Aggregate aux outputs lists into tensors + for i, aux in enumerate(aggregated_aux_outputs): + if isinstance(aux, list): + aggregated_aux_outputs[i] = torch.cat(aux, dim=0) + + ctx.save_for_backward( + torch.cat(grad_inputs, dim=0), + grad_weight, + grad_bias, + ) + return_vars = ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits_mean, + policy_rejected_logits_mean, + policy_nll_loss, + ) + return loss_acc, (*return_vars, *aggregated_aux_outputs) + + @staticmethod + def backward(ctx, *grad_output): + grad_input, grad_weight, grad_bias = ctx.saved_tensors + if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)): + grad_input = grad_input * grad_output[0][0] + grad_weight = grad_weight * grad_output[0][0] + grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None + + return grad_input, grad_weight, None, grad_bias, None, None, None, None + + @staticmethod + def chunk_forward( + input_chunk, + weight, + target_chunk, + bias=None, + 
ignore_index=-100, + compute_nll_loss=True, + chosen_nll_target_chunk=None, + average_log_prob=True, + ): + len_chosen_chunk = target_chunk.shape[0] // 2 + logits_chunk = input_chunk @ weight.t() + if bias is not None: + logits_chunk = logits_chunk + bias + log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) + + chosen_nll_loss = 0.0 + if compute_nll_loss: + nll_labels = ( + chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk] + ) + chosen_nll_loss = F.nll_loss( + log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), + nll_labels.view(-1), + reduction="sum", + ignore_index=ignore_index, + ) + + loss_mask = target_chunk != ignore_index + label_chunk = torch.where(loss_mask, target_chunk, 0) + + per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1) + if average_log_prob: + log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + log_prob = (per_token_logps * loss_mask).sum(-1) + + chosen_logps = log_prob[:len_chosen_chunk] + rejected_logps = log_prob[len_chosen_chunk:] + + chosen_logits = logits_chunk[:len_chosen_chunk] + rejected_logits = logits_chunk[len_chosen_chunk:] + + return ( + chosen_logps, + rejected_logps, + chosen_logits, + rejected_logits, + chosen_nll_loss, + ) + + @staticmethod + def _compute_loss( + input_chunk, + weight, + target_chunk, + bias=None, + preference_loss_fn=None, + full_target=None, + ignore_index=-100, + alpha=1.0, + beta=0.1, + compute_nll_loss=True, + use_ref_model=False, + ref_input_chunk=None, + ref_weight=None, + ref_bias=None, + full_nll_target=None, + chosen_nll_target_chunk=None, + average_log_prob=True, + **loss_kwargs, + ): + """ + Compute the total loss for a chunk of input and target, while using an alignment/preference loss function. + Args: + preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size). + weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). + target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length). + bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length). + ignore_index (int): Index to ignore for loss computation. + alpha (float): Weight for the NLL loss. + beta (float): Weight for the preference loss. + compute_nll_loss (bool): Whether to compute NLL loss. + use_ref_model (bool): Whether to use a reference model for the alignment loss. + ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). + ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). + full_nll_target (torch.Tensor, optional): Full target tensor for NLL loss. Shape: (batch_size, sequence_length). + chosen_nll_target_chunk (torch.Tensor, optional): Target tensor for NLL loss. Shape: (chunk_size, sequence_length) If not provided the target_chunk is used. + average_log_prob (bool): Whether to average log probabilities or the sum. + loss_kwargs (dict): Additional arguments for the loss function. 
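+        Returns:
+            The combined loss for the chunk (alpha * NLL loss + preference loss), plus a tuple of
+            (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, chosen_nll_loss, *aux_outputs).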
+ """ + ( + chosen_logps, + rejected_logps, + chosen_logits, + rejected_logits, + chosen_nll_loss, + ) = LigerFusedLinearPreferenceBase.chunk_forward( + input_chunk, + weight, + target_chunk, + bias=bias, + ignore_index=ignore_index, + compute_nll_loss=compute_nll_loss, + chosen_nll_target_chunk=chosen_nll_target_chunk, + average_log_prob=average_log_prob, + ) + if full_nll_target is not None: + chosen_nll_loss = chosen_nll_loss / (full_nll_target[: full_nll_target.shape[0] // 2] != ignore_index).sum() + else: + chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() + + chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]) + rejected_logits_mean = rejected_logits.sum() / ( + full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] + ) + + if use_ref_model: + with torch.no_grad(): + ( + ref_chosen_logps, + ref_rejected_logps, + _, + _, + _, + ) = LigerFusedLinearPreferenceBase.chunk_forward( + ref_input_chunk, + ref_weight, + target_chunk, + ref_bias, + ignore_index=ignore_index, + compute_nll_loss=False, # We don't need NLL loss for the reference model + chosen_nll_target_chunk=None, + average_log_prob=average_log_prob, + ) + loss_kwargs["ref_chosen_logps"] = ref_chosen_logps + loss_kwargs["ref_rejected_logps"] = ref_rejected_logps + + preference_loss_outputs = preference_loss_fn( + chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs + ) + if isinstance(preference_loss_outputs, tuple): + preference_loss, *aux_outputs = preference_loss_outputs + else: + preference_loss, aux_outputs = preference_loss_outputs, [] + + loss = alpha * chosen_nll_loss + preference_loss + return_vars = ( + chosen_logps, + rejected_logps, + chosen_logits_mean, + rejected_logits_mean, + chosen_nll_loss, + ) + return loss, (*return_vars, *aux_outputs) diff --git a/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py b/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py new file mode 100755 index 0000000000000000000000000000000000000000..73118e493dd9882d4debd5c4ddce8c6b5e6bacaa --- /dev/null +++ b/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py @@ -0,0 +1,341 @@ +from abc import abstractmethod +from functools import partial + +import torch + +from torch.nn import functional as F + + +class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function): + @abstractmethod + def preference_loss_fn(*args, **kwargs): + """ + To be extended by subclasses. + """ + raise NotImplementedError("Preference loss function must be implemented.") + + @staticmethod + def forward( + cls, + ctx, + _input, + weight, + target, + preference_labels, + bias=None, + chunk_size=1, + ignore_index=-100, + compiled=True, + use_ref_model=False, + ref_input=None, + ref_weight=None, + ref_bias=None, + average_log_prob=False, + **loss_kwargs, + ): + """ + Base class for fused linear layer with unpaired preference loss like KTO + Expects _input to be stacked with chosen and rejected inputs on the batch dimension. + + The mental model is: + + forward() + ├── Loop over chunks + └── compute_loss() + ├── chunk_forward() # Compute logits and log probs + └── prefer_loss() # Calculate preference loss + + Args: + _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size). + weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). + target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len). + bias (torch.Tensor, optional): Bias tensor. 
Shape: (vocab_size,). + loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs). + ignore_index (int): Index to ignore for loss computation. + beta (float): Weight for the preference loss. + compiled (bool): Whether to use torch compile for chunk accumulation. + use_ref_model (bool): Whether to use a reference model for the alignment loss. + preference_labels (torch.Tensor): Boolean tensor indicating chosen (True) vs rejected (False) examples. + Shape: (batch_size,). + ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). + ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). + average_log_prob (bool): Whether to average the log probability per non-masked token. + loss_kwargs (dict): Other possible arguments that a loss function might need + """ + # TODO: Tune CHUNK_SIZE to fully utilize the GPU + CHUNK_SIZE = chunk_size + + # Gradients to be accumulated + grad_inputs = [] + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) if bias is not None else None + + # Loss to be accumulated + loss_acc = torch.zeros((), device=_input.device) + + # Metrics to be recorded + chosen_logps_sum = torch.zeros((), device=_input.device) + rejected_logps_sum = torch.zeros((), device=_input.device) + chosen_logits_sum = torch.zeros((), device=_input.device) + rejected_logits_sum = torch.zeros((), device=_input.device) + aggregated_aux_outputs = [] + + compute_loss = partial( + LigerFusedLinearUnpairedPreferenceBase._compute_loss, + preference_loss_fn=cls.preference_loss_fn, + full_target=target, + ignore_index=ignore_index, + use_ref_model=use_ref_model, + ref_weight=ref_weight, + ref_bias=ref_bias, + average_log_prob=average_log_prob, + **loss_kwargs, + ) + + def fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk): + """ + Fused forward and backward pass for a chunk of input and target. 
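+            Uses torch.func.grad_and_value to differentiate with respect to the input,
+            weight, and (when present) bias, returning those gradients together with
+            the loss and its aux outputs in a single pass.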
+ """ + argnums = (0, 1, 4) if bias is not None else (0, 1) + return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)( + input_chunk, + weight, + target_chunk, + preference_labels_chunk, + bias, + ref_input_chunk=ref_input_chunk, + ) + + def accumulate_chunk( + input_chunk, + target_chunk, + preference_labels_chunk=None, + ref_input_chunk=None, + ): + ( + (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), + ( + chunk_loss, + ( + chunk_chosen_logps_sum, + chunk_rejected_logps_sum, + chunk_chosen_logits_sum, + chunk_rejected_logits_sum, + *aux_outputs, + ), + ), + ) = fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk) + if bias is not None: + grad_bias.add_(chunk_grad_bias[0]) # accumulate bias gradient + + # Accumulate gradients + grad_weight.add_(chunk_grad_weight) + grad_inputs.append(chunk_grad_input) + + # Accumulate loss + loss_acc.add_(chunk_loss) + + # Accumulate metrics + chosen_logps_sum.add_(chunk_chosen_logps_sum) + rejected_logps_sum.add_(chunk_rejected_logps_sum) + chosen_logits_sum.add_(chunk_chosen_logits_sum) + rejected_logits_sum.add_(chunk_rejected_logits_sum) + + # aux_outputs + # Initialize storage for aux_outputs + if len(aggregated_aux_outputs) == 0: + for aux in aux_outputs: + aggregated_aux_outputs.append(torch.zeros((), device=aux.device)) + + # Process each aux_output + for i, aux in enumerate(aux_outputs): + if aux.ndim == 0: + aggregated_aux_outputs[i].add_(aux) + + if compiled: + fused_fwd_bwd = torch.compile(fused_fwd_bwd) + + # When not paired, use labels to separate chosen and rejected + assert preference_labels is not None, "preference_labels must be provided for unpaired preference loss" + + chunks = max(1, _input.shape[0] // CHUNK_SIZE) + _input_chunks = torch.chunk(_input, chunks=chunks, dim=0) + _target_chunks = torch.chunk(target, chunks=chunks, dim=0) + _preference_labels_chunks = torch.chunk(preference_labels, chunks=chunks, dim=0) + + if use_ref_model: + _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) + + for ( + input_chunk, + target_chunk, + ref_input_chunk, + preference_labels_chunk, + ) in zip( + _input_chunks, + _target_chunks, + (_ref_input_chunks if use_ref_model else [None] * len(_input_chunks)), + _preference_labels_chunks, + ): + # mark input_chunk, target_chunk, and target dimension 1 (sequence length) as dynamic to prevent torch.compile recompilation + torch._dynamo.mark_dynamic(input_chunk, 1) + torch._dynamo.mark_dynamic(target_chunk, 1) + torch._dynamo.mark_dynamic(target, 1) + torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None + torch._dynamo.mark_dynamic(preference_labels_chunk, 1) + + # accumulate loss, gradients, and metrics + accumulate_chunk(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk) + + # Aggregate aux outputs lists into tensors + for i, aux in enumerate(aggregated_aux_outputs): + if isinstance(aux, list): + aggregated_aux_outputs[i] = torch.cat(aux, dim=0) + + ctx.save_for_backward( + torch.cat(grad_inputs, dim=0), + grad_weight, + grad_bias, + ) + + return_vars = ( + chosen_logps_sum, + rejected_logps_sum, + chosen_logits_sum, + rejected_logits_sum, + ) + + return loss_acc, (*return_vars, *aggregated_aux_outputs) + + @staticmethod + def backward(ctx, *grad_output): + grad_input, grad_weight, grad_bias = ctx.saved_tensors + if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)): + grad_input = grad_input * grad_output[0][0] + grad_weight = grad_weight * grad_output[0][0] + 
grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None  # scale bias grad by the same upstream scalar
+
+        return grad_input, grad_weight, None, None, grad_bias
+
+    @staticmethod
+    def chunk_forward(
+        input_chunk,
+        weight,
+        target_chunk,
+        preference_labels_chunk,
+        bias=None,
+        ignore_index=-100,
+        average_log_prob=False,
+    ):
+        logits_chunk = input_chunk @ weight.t()
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+        loss_mask_chunk = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask_chunk, target_chunk, 0)
+
+        per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
+        if average_log_prob:
+            log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
+        else:
+            log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1)
+
+        chosen_logps_sum = (log_probs * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_logps_sum = (log_probs * (~preference_labels_chunk).unsqueeze(1)).sum()
+
+        chosen_logits_sum = (logits_chunk * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_logits_sum = (logits_chunk * (~preference_labels_chunk).unsqueeze(1)).sum()
+
+        return (
+            log_probs,
+            chosen_logps_sum,
+            rejected_logps_sum,
+            chosen_logits_sum,
+            rejected_logits_sum,
+        )
+
+    @staticmethod
+    def _compute_loss(
+        input_chunk,
+        weight,
+        target_chunk,
+        preference_labels_chunk,
+        bias=None,
+        preference_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        use_ref_model=False,
+        ref_input_chunk=None,
+        ref_weight=None,
+        ref_bias=None,
+        average_log_prob=False,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target using an alignment/preference loss function.
+        Args:
+            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, sequence_length, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size, sequence_length).
+            preference_labels_chunk (torch.Tensor): Chunk of boolean preference labels. Shape: (chunk_size,).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
+            loss_kwargs (dict): Additional arguments for the loss function.
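+        Returns:
+            tuple: (preference_loss_chunk, (chosen_logps_sum, rejected_logps_sum,
+            chosen_logits_sum, rejected_logits_sum, *aux_outputs)).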
+ """ + ( + log_prob_chunk, + chosen_logps_sum, + rejected_logps_sum, + chosen_logits_sum, + rejected_logits_sum, + ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward( + input_chunk, + weight, + target_chunk, + preference_labels_chunk, + bias=bias, + ignore_index=ignore_index, + average_log_prob=average_log_prob, + ) + + if use_ref_model: + with torch.no_grad(): + ( + ref_log_prob_chunk, + _, + _, + _, + _, + ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward( + ref_input_chunk, + ref_weight, + target_chunk, + preference_labels_chunk, + ref_bias, + ignore_index=ignore_index, + average_log_prob=average_log_prob, + ) + loss_kwargs["ref_log_prob_chunk"] = ref_log_prob_chunk + + preference_loss_outputs = preference_loss_fn( + log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs + ) + if isinstance(preference_loss_outputs, tuple): + preference_loss_chunk, *aux_outputs = preference_loss_outputs + else: + preference_loss_chunk, aux_outputs = preference_loss_outputs, [] + + return_vars = ( + chosen_logps_sum, + rejected_logps_sum, + chosen_logits_sum, + rejected_logits_sum, + ) + + return preference_loss_chunk, (*return_vars, *aux_outputs) diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..f05cc874468028274a3f8f3b53698fbe1a9ff3e1 --- /dev/null +++ b/src/liger_kernel/chunked_loss/grpo_loss.py @@ -0,0 +1,462 @@ +from typing import Optional + +import torch + +from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase + + +def k3_loss_fn(log_p, log_q): + # computes k3 estimate of KL[q, p] + # ref: http://joschu.net/blog/kl-approx.html + return torch.exp(log_p - log_q) - (log_p - log_q) - 1.0 + + +def sapo_loss_fn(importance_ratio: torch.Tensor, temperature: float) -> torch.Tensor: + """SAPO (Soft Adaptive Policy Optimization) loss function. + + Replaces hard clipping with a smooth, temperature-controlled gate that + adaptively attenuates off-policy updates while preserving useful learning signals. + + Reference: https://huggingface.co/papers/2511.20347 + TRL implementation: https://github.com/huggingface/trl/blob/1bd2a52ec2d8344050af736d60cdc735181ae4b8/trl/trainer/grpo_trainer.py#L1913 + + Args: + importance_ratio: The importance sampling ratio (pi_theta / pi_old). + temperature: Temperature parameter controlling the softness of the gate. + + Returns: + The SAPO loss value. 
+ """ + if temperature <= 0: + raise ValueError("sapo_temperature must be > 0.") + sigmoid_input = temperature * (importance_ratio - 1) + sigmoid_smoothed_loss = torch.sigmoid(sigmoid_input) + return sigmoid_smoothed_loss * 4 / temperature + + +def clip_coef_fn(coef, epsilon_low, epsilon_high, loss_type): + if loss_type == "cispo": + # CISPO: clip and detach the importance weights + upper_bound = epsilon_high + lower_bound = None + clipped_coef = torch.clamp(coef, lower_bound, upper_bound).detach() + is_lower_clipped = False + is_upper_clipped = coef > upper_bound + elif loss_type == "sapo": + # SAPO doesn't use clipping metrics + clipped_coef = None + is_lower_clipped = torch.zeros_like(coef, dtype=torch.bool) + is_upper_clipped = torch.zeros_like(coef, dtype=torch.bool) + else: + upper_bound = 1 + epsilon_high + lower_bound = 1 - epsilon_low + clipped_coef = torch.clamp(coef, lower_bound, upper_bound) + is_lower_clipped = coef < lower_bound + is_upper_clipped = coef > upper_bound + return clipped_coef, is_lower_clipped, is_upper_clipped + + +class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase): + @staticmethod + def ppo_loss_fn( + log_probs, + selected_token_ids, + attention_mask, + advantages, + full_attention_mask, + ref_per_token_logps=None, # shape: [chunk_size, seq_len] + old_per_token_logps=None, + ref_log_probs=None, # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size]) + epsilon_low=0.2, + epsilon_high=0.2, + beta=0.04, + loss_type="dapo", # ["grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo"] + max_completion_length=None, # Required for dr_grpo + importance_sampling_level="token", # ["token", "sequence"] - new parameter for GSPO + sapo_temperature_pos=1.0, # Temperature for positive advantages in SAPO + sapo_temperature_neg=1.05, # Temperature for negative advantages in SAPO + vllm_is_ratio=None, # vLLM importance sampling ratio (chunk_size, seq_len) or (chunk_size, 1) or None + delta=None, # Upper clamp for two-sided clipping (INTELLECT-2) + use_bias_correction_kl=False, # Importance-sampling-corrected KL (DeepSeek-V3.2) + **kwargs, + ): + """GRPO Loss Function matching GRPOTrainer implementation.""" + # Validate sequence-level + loss_type combinations + if importance_sampling_level == "sequence" and loss_type in ("cispo", "sapo"): + raise ValueError( + f"Sequence-level importance sampling is not supported for loss_type='{loss_type}'. " + f"Use importance_sampling_level='token' instead." + ) + + per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze( + -1 + ) # (batch_size, seq_len) + + # Get reference model probabilities + if ref_per_token_logps is None: + if ref_log_probs is not None: + with torch.no_grad(): + ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze( + -1 + ) + else: + ref_per_token_logps = per_token_logps.detach() + + # Compute policy gradient loss with importance sampling ratio + old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach() + log_ratio = per_token_logps - old_per_token_logps + + if importance_sampling_level == "token": + log_importance_weights = log_ratio + elif importance_sampling_level == "sequence": + log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0) + log_importance_weights = log_importance_weights.unsqueeze(-1) + else: + raise ValueError( + f"Unknown importance sampling level: {importance_sampling_level}. 
Possible values are 'token' " + "and 'sequence'." + ) + + # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on + # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) + coef_1 = torch.exp(log_importance_weights) + coef_2, is_lower_clipped, is_upper_clipped = clip_coef_fn(coef_1, epsilon_low, epsilon_high, loss_type) + if loss_type == "cispo": + # CISPO: clip and detach the importance weights, multiply by log probs + # Reference: https://github.com/huggingface/trl/blob/035c3ff151b953ca72cdfe0ee966bc1469a26fde/trl/trainer/grpo_trainer.py#L2030 + per_token_loss = -coef_2 * advantages.unsqueeze(1) * per_token_logps + elif loss_type == "sapo": + # SAPO: Soft Adaptive Policy Optimization + # Uses sigmoid-based soft gating instead of hard clipping + # Reference: https://huggingface.co/papers/2511.20347 + # TRL implementation: https://github.com/huggingface/trl/blob/1bd2a52ec2d8344050af736d60cdc735181ae4b8/trl/trainer/grpo_trainer.py#L2037-L2046 + per_token_loss = torch.empty_like(coef_1) + # Expand advantages to match coef_1 shape for masking + advantages_expanded = advantages.unsqueeze(1).expand_as(coef_1) + positive_advantages_mask = advantages_expanded > 0 + + # Apply different temperatures based on advantage sign + per_token_loss[positive_advantages_mask] = sapo_loss_fn( + coef_1[positive_advantages_mask], sapo_temperature_pos + ) + per_token_loss[~positive_advantages_mask] = sapo_loss_fn( + coef_1[~positive_advantages_mask], sapo_temperature_neg + ) + per_token_loss = -per_token_loss * advantages_expanded + else: + # Apply delta (two-sided clipping from INTELLECT-2) to coef_1 + if delta is not None: + coef_1 = torch.clamp(coef_1, max=delta) + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + + # Apply vLLM importance sampling correction BEFORE adding KL penalty + if vllm_is_ratio is not None: + per_token_loss = per_token_loss * vllm_is_ratio + + if beta != 0.0: + # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps]) + kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps) + if use_bias_correction_kl: + # Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= token-level coef_1 + token_coef_1 = torch.exp(per_token_logps - old_per_token_logps) + kl_div = kl_div * token_coef_1 + # Combine losses + per_token_loss = per_token_loss + beta * kl_div + + # Note: We normalize by the number of tokens in the batch (using full_attention_mask), + # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1) + # and TRL GRPO implementation + # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966) + if loss_type == "grpo" or loss_type == "sapo": + # Average per-sequence loss (SAPO uses same normalization as GRPO) + loss = ( + (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0) + ).sum() / full_attention_mask.shape[0] + elif loss_type == "bnpo": + # Batch Normalized Per-token loss (original implementation) + loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0) + elif loss_type == "dr_grpo": + # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length) + if max_completion_length is None: + raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'") + loss = (per_token_loss * 
attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length) + elif loss_type == "dapo" or loss_type == "cispo": + loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask) + loss = (per_token_loss * attention_mask).sum() / loss_normalizer + elif loss_type == "luspo": + # LUSPO: loss = (per_token_loss * mask.sum(1, keepdim=True)).mean() + # Reformulated as: sum_i(sum_j(per_token_loss_ij) * seq_len_i) / numel + # to avoid (B,T) * (B,1) broadcast which amplifies torch.compile differences. + seq_lens = attention_mask.sum(-1) # (chunk_B,) + per_seq_sum = per_token_loss.sum(-1) # (chunk_B,) + weighted = per_seq_sum * seq_lens # (chunk_B,) + if importance_sampling_level == "sequence" and beta == 0.0: + # per_token_loss stays (B, 1), so .mean() divides by B + loss = weighted.sum() / full_attention_mask.shape[0] + else: + # per_token_loss is (B, T), .mean() divides by B*T + loss = weighted.sum() / (full_attention_mask.shape[0] * full_attention_mask.shape[1]) + else: + raise ValueError(f"Unknown loss type: {loss_type}") + + # Calculate metrics + metrics = [] + if beta != 0.0: + metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))) + + # Adjust clipping metric calculation based on importance sampling level + if importance_sampling_level == "token": + is_clipped = (is_lower_clipped & (advantages.unsqueeze(1) < 0)) | ( + is_upper_clipped & (advantages.unsqueeze(1) > 0) + ) + else: # sequence level + # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,) + is_clipped = (is_lower_clipped & (advantages.unsqueeze(1) < 0)) | ( + is_upper_clipped & (advantages.unsqueeze(1) > 0) + ) + is_clipped = is_clipped.expand_as(attention_mask) + + metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)) + return loss, metrics + + @classmethod + def forward( + cls, + ctx, + _input, + weight, + selected_token_ids, + attention_mask, + advantages, + bias=None, + ref_per_token_logps=None, + old_per_token_logps=None, + ref_input=None, + ref_weight=None, + ref_bias=None, + beta=0.04, + epsilon_low=0.2, + epsilon_high=0.2, + loss_type="dapo", + max_completion_length=None, + importance_sampling_level="token", + sapo_temperature_pos=1.0, + sapo_temperature_neg=1.05, + temperature=1.0, + compiled=True, + use_ref_model=True, + chunk_size=1, + vllm_is_ratio=None, + delta=None, + use_bias_correction_kl=False, + ): + """ + Fused linear layer with GRPO loss. + Args: + _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size) + weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size) + selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len) + attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len) + advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,) + bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,) + ref_per_token_logps: Reference model log probs per token tensor. Shape:(batch_size, seq_len) + ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size) + ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size) + ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,) + beta (float): Weight for the KL penalty + loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo"). + Defaults to "dapo". 
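+                For "cispo", epsilon_high is typically larger (e.g. 5.0) and epsilon_low is unused.
+                For "sapo", a soft sigmoid gate replaces hard clipping.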
+ max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None. + importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token". + sapo_temperature_pos (float): Temperature for positive advantages in SAPO. Defaults to 1.0. + sapo_temperature_neg (float): Temperature for negative advantages in SAPO. Defaults to 1.05. + temperature (float): Temperature for the logits + compiled (bool): Whether to use torch compile + use_ref_model (bool): Whether to use a reference model + chunk_size (int): Size of chunks for processing. + vllm_is_ratio (torch.Tensor, optional): vLLM importance sampling ratio (batch_size, seq_len) or (batch_size, 1) or None. + Used to correct for distribution mismatch when using vLLM for generation. + Returns: + torch.Tensor: Computed loss + """ + # Validate before entering torch.compile boundary + if importance_sampling_level == "sequence" and loss_type in ("cispo", "sapo"): + raise ValueError( + f"Sequence-level importance sampling is not supported for loss_type='{loss_type}'. " + f"Use importance_sampling_level='token' instead." + ) + + return super().forward( + cls=cls, + ctx=ctx, + _input=_input, + weight=weight, + selected_token_ids=selected_token_ids, + attention_mask=attention_mask, + advantages=advantages, + bias=bias, + ref_per_token_logps=ref_per_token_logps, + old_per_token_logps=old_per_token_logps, + ref_input=ref_input, + ref_weight=ref_weight, + ref_bias=ref_bias, + beta=beta, + epsilon_low=epsilon_low, + epsilon_high=epsilon_high, + loss_type=loss_type, + max_completion_length=max_completion_length, + temperature=temperature, + compiled=compiled, + use_ref_model=use_ref_model, + chunk_size=chunk_size, + importance_sampling_level=importance_sampling_level, + sapo_temperature_pos=sapo_temperature_pos, + sapo_temperature_neg=sapo_temperature_neg, + vllm_is_ratio=vllm_is_ratio, + delta=delta, + use_bias_correction_kl=use_bias_correction_kl, + ) + + @staticmethod + def backward(ctx, grad_output, *grad_metrics): + """Backward pass for GRPO loss. 
+ + Args: + grad_output: Gradient of the loss (scalar) + grad_metrics: Gradients of the metrics (not used in backward computation) + """ + grads = LigerFusedLinearPPOBase.backward(ctx, grad_output) + return ( + *grads[ + :6 + ], # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias + None, # grad_ref_per_token_logps + None, # grad_old_per_token_logps + None, # grad_ref_input + None, # grad_ref_weight + None, # grad_ref_bias + None, # grad_beta + None, # grad_epsilon_low + None, # grad_epsilon_high + None, # grad_loss_type (string, not differentiable) + None, # grad_max_completion_length (int, not differentiable) + None, # grad_importance_sampling_level (string, not differentiable) + None, # grad_sapo_temperature_pos (float, not differentiable) + None, # grad_sapo_temperature_neg (float, not differentiable) + None, # grad_temperature + None, # grad_compiled + None, # grad_use_ref_model + None, # grad_chunk_size + None, # grad_vllm_is_ratio + None, # grad_delta + None, # grad_use_bias_correction_kl + ) + + +class LigerFusedLinearGRPOLoss(torch.nn.Module): + """Fused linear layer with GRPO loss.""" + + def __init__( + self, + beta: float = 0.04, + compiled: bool = True, + use_ref_model: bool = True, + chunk_size: int = 1, + epsilon_low: float = 0.2, + epsilon_high: float = 0.2, + loss_type: str = "dapo", + max_completion_length: Optional[int] = None, + importance_sampling_level: str = "token", + sapo_temperature_pos: float = 1.0, + sapo_temperature_neg: float = 1.05, + temperature: float = 1.0, + delta: Optional[float] = None, + use_bias_correction_kl: bool = False, + ): + """ + Args: + beta (float): Weight for the KL penalty. + compiled (bool): Whether to use torch compile. + use_ref_model (bool): Whether to use a reference model. + chunk_size (int): Size of chunks for processing. + epsilon_low (float): Lower bound for the importance sampling ratio. + epsilon_high (float): Upper bound for the importance sampling ratio. + loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo"). + Defaults to "dapo". For "cispo", epsilon_high is typically larger (e.g. 5.0) and + epsilon_low is unused. For "sapo", uses soft gating instead of hard clipping. + max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None. + importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token". + sapo_temperature_pos (float): Temperature for positive advantages in SAPO. Defaults to 1.0. + sapo_temperature_neg (float): Temperature for negative advantages in SAPO. Defaults to 1.05. + temperature (float): Temperature for the logits. + delta (float, optional): Upper clamp for two-sided clipping (INTELLECT-2). None means disabled. + use_bias_correction_kl (bool): If True, multiply KL by importance sampling ratio (DeepSeek-V3.2). 
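+
+        Example:
+            A minimal usage sketch; the tensor names and shapes below are
+            illustrative placeholders rather than part of the API:
+
+            >>> loss_fn = LigerFusedLinearGRPOLoss(beta=0.04, loss_type="dapo")
+            >>> # hidden: (B*T, H) last hidden states; lm_head_weight: (V, H)
+            >>> outputs = loss_fn(hidden, lm_head_weight, token_ids, mask, advantages)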
+ """ + super().__init__() + # Validate SAPO temperatures to prevent division by zero or numerical instability + if sapo_temperature_pos <= 0: + raise ValueError(f"sapo_temperature_pos must be positive, got {sapo_temperature_pos}") + if sapo_temperature_neg <= 0: + raise ValueError(f"sapo_temperature_neg must be positive, got {sapo_temperature_neg}") + if delta is not None and delta <= 0: + raise ValueError(f"delta must be positive, got {delta}") + self.beta = beta + self.compiled = compiled + self.use_ref_model = use_ref_model + self.chunk_size = chunk_size + self.epsilon_low = epsilon_low + self.epsilon_high = epsilon_high + self.loss_type = loss_type + self.max_completion_length = max_completion_length + self.importance_sampling_level = importance_sampling_level + self.sapo_temperature_pos = sapo_temperature_pos + self.sapo_temperature_neg = sapo_temperature_neg + self.temperature = temperature + self.delta = delta + self.use_bias_correction_kl = use_bias_correction_kl + + def forward( + self, + _input, + lin_weight, + selected_token_ids, + attention_mask, + advantages, + bias=None, + ref_per_token_logps=None, + old_per_token_logps=None, + ref_input=None, + ref_weight=None, + ref_bias=None, + vllm_is_ratio=None, + ): + return LigerFusedLinearGRPOFunction.apply( + _input, + lin_weight, + selected_token_ids, + attention_mask, + advantages, + bias, + ref_per_token_logps, + old_per_token_logps, + ref_input, + ref_weight, + ref_bias, + self.beta, + self.epsilon_low, + self.epsilon_high, + self.loss_type, + self.max_completion_length, + self.importance_sampling_level, + self.sapo_temperature_pos, + self.sapo_temperature_neg, + self.temperature, + self.compiled, + self.use_ref_model, + self.chunk_size, + vllm_is_ratio, + self.delta, + self.use_bias_correction_kl, + ) diff --git a/src/liger_kernel/chunked_loss/jsd_loss.py b/src/liger_kernel/chunked_loss/jsd_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..64cc75a40dc7272d18271a465b5286087514c117 --- /dev/null +++ b/src/liger_kernel/chunked_loss/jsd_loss.py @@ -0,0 +1,215 @@ +import math + +from typing import Tuple +from typing import Union + +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase + + +class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase): + @staticmethod + def distillation_loss_fn(student_logits, teacher_logits, beta=0.5, target=None, ignore_index=-100): + """ + Compute JSD loss (Jensen-Shannon Divergence Loss). + Args: + student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,). + teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,). + beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`. + target (torch.Tensor): Target labels for masking. Shape: (chunk_size,). + ignore_index (int): Index to ignore in loss computation. 
+ Returns: + torch.Tensor: Jensen-Shannon Divergence loss + Note: + - Uses reduction="none" to preserve per-token losses for masking + - KL divergence requires summing over vocab dimension (not mean) + - Masking excludes padding/prompt tokens from loss computation + """ + student_log_probs = F.log_softmax(student_logits, dim=-1) + teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) + + if beta == 0: + jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) + elif beta == 1: + jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) + else: + # Compute probabilities (only required for mean calculation) + log_mean_probs = torch.logsumexp( + torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0 + ) + + student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="none", log_target=True) + teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="none", log_target=True) + + # JSD is the weighted average of the KL divergences + jsd_loss = beta * teacher_kl + (1 - beta) * student_kl + + # Sum over vocab dimension (KL divergence definition) + jsd_loss = jsd_loss.sum(dim=-1) # (chunk_size,) + + # Apply ignore_index mask + if target is not None: + mask = target != ignore_index + jsd_loss = jsd_loss.masked_fill(~mask, 0.0) + + return jsd_loss.sum() + + @classmethod + def forward( + cls, + ctx, + student_input: torch.Tensor, + student_weight: torch.Tensor, + teacher_input: torch.Tensor, + teacher_weight: torch.Tensor, + true_labels: torch.LongTensor, + student_bias: torch.Tensor, + teacher_bias: torch.Tensor, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + compiled: bool = True, + chunk_size: int = 1024, + return_soft_hard_loss: bool = False, + ): + """ + Fused linear layer with JSD distillation loss. + Args: + student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size_student) + student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size_student) + teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size_teacher) + teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size_teacher) + true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,) + weight_hard_loss (float): Weight for hard loss. + weight_soft_loss (float): Weight for soft loss. + beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`. + ignore_index (int): Index to ignore in loss computation + temperature (float): Temperature for softening/sharpening distributions + compiled (bool): Whether to use torch compile + chunk_size (int): Size of chunks for processing. + return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False. 
+ Returns: + torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True + """ + return super().forward( + cls=cls, + ctx=ctx, + student_input=student_input, + student_weight=student_weight, + teacher_input=teacher_input, + teacher_weight=teacher_weight, + target=true_labels, + student_bias=student_bias, + teacher_bias=teacher_bias, + chunk_size=chunk_size, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + beta=beta, + ignore_index=ignore_index, + temperature=temperature, + compiled=compiled, + return_soft_hard_loss=return_soft_hard_loss, + ) + + @staticmethod + def backward(ctx, grad_output, *args): + grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6] + + return ( + *grads, + None, # teacher_bias + None, # weight_hard_loss + None, # weight_soft_loss + None, # beta + None, # ignore_index + None, # temperature + None, # compiled + None, # chunk_size + None, # return_soft_hard_loss + ) + + +class LigerFusedLinearJSDLoss(torch.nn.Module): + """ + Fused linear layer with JSD distillation loss. + """ + + def __init__( + self, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + compiled: bool = True, + chunk_size: int = 1024, + return_soft_hard_loss: bool = False, + ): + """ + Args: + weight_hard_loss (float): Weight for hard loss. + weight_soft_loss (float): Weight for soft loss. + ignore_index (int): Index to ignore in the loss + temperature (float): Temperature for softening distributions + compiled (bool): Whether to use torch compile + beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`. + chunk_size (int): Size of chunks for processing. + return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False. + """ + super().__init__() + assert temperature != 0, "Temperature cannot be 0." + self.weight_hard_loss = weight_hard_loss + self.weight_soft_loss = weight_soft_loss + self.ignore_index = ignore_index + self.temperature = temperature + self.compiled = compiled + self.beta = beta + self.chunk_size = chunk_size + self.return_soft_hard_loss = return_soft_hard_loss + + def forward( + self, + student_input: torch.Tensor, + student_weight: torch.Tensor, + teacher_input: torch.Tensor, + teacher_weight: torch.Tensor, + true_labels: torch.LongTensor, + student_bias: torch.Tensor = None, + teacher_bias: torch.Tensor = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + """ + Compute the JSD distillation loss. 
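+        The soft (generalized JSD) term and the hard (true-label) term are
+        combined according to the weight_soft_loss and weight_hard_loss set at
+        construction time.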
+ + Args: + student_input (torch.Tensor): Student input tensor + student_weight (torch.Tensor): Student weight tensor + teacher_input (torch.Tensor): Teacher input tensor + teacher_weight (torch.Tensor): Teacher weight tensor + true_labels (torch.LongTensor): Target labels tensor + + Returns: + torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + If return_soft_hard_loss is False: Computed combined loss + If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss) + """ + return LigerFusedLinearJSDFunction.apply( + student_input, + student_weight, + teacher_input, + teacher_weight, + true_labels, + student_bias, + teacher_bias, + self.weight_hard_loss, + self.weight_soft_loss, + self.beta, + self.ignore_index, + self.temperature, + self.compiled, + self.chunk_size, + self.return_soft_hard_loss, + ) diff --git a/src/liger_kernel/chunked_loss/kto_loss.py b/src/liger_kernel/chunked_loss/kto_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..e7b4d503368815de8f59ce7353b0adef9c901705 --- /dev/null +++ b/src/liger_kernel/chunked_loss/kto_loss.py @@ -0,0 +1,210 @@ +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFusedLinearUnpairedPreferenceBase + + +class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase): + @staticmethod + def preference_loss_fn( + log_prob_chunk, + preference_labels_chunk, + full_target, + ref_log_prob_chunk=None, + beta=0.1, + kl=None, + ): + """ + Implements the Kahneman-Tversky Optimization (KTO) loss function. + Paper: "KTO: Model Alignment as Prospect Theory-Guided Optimization" + https://arxiv.org/abs/2402.01306 + + KTO loss is inspired by prospect theory (https://en.wikipedia.org/wiki/Prospect_theory) + from behavioral economics, which models how humans make decisions under uncertainty. + The loss function is asymmetric, treating gains and losses differently, similar to + human decision-making patterns. + + Formula: + When y is chosen: + L_KTO = 1 - σ(β * (log[π(x)/π₀(x)] - KL(π||π₀)_y)) + When y is rejected: + L_KTO = 1 - σ(β * (KL(π||π₀)_y - log[π(x)/π₀(x)])) + + Where: + - σ: Sigmoid function + - β: Temperature parameter controlling the strength of the preference signal + - π(x): Policy (current model) + - π₀(x): Reference policy (reference model) + - KL(π||π₀)_y: KL divergence estimated using the rejected response y + + The loss encourages the model to: + 1. Assign higher probability to chosen responses + 2. Assign lower probability to rejected responses + 3. Maintain reasonable distance from the reference model + + Args: + log_prob_chunk: Log probabilities for the chunk (batch_size,) + preference_labels_chunk: Preference labels for the chunk (batch_size,) + full_target: Non chunked full target tensor + ref_log_prob_chunk: Reference log probs for the chunk (batch_size,) + beta: Weight for the KTO loss + kl: KL divergence between the policy model and the reference model for the chosen responses. 
Shape: (batch_size,)
+        Returns:
+            tuple: (loss, chosen_rewards_sum, rejected_rewards_sum), where loss is the
+            KTO loss value and the reward sums are recorded as metrics.
+        """
+        if ref_log_prob_chunk is not None:
+            logratios_chunk = log_prob_chunk - ref_log_prob_chunk
+        else:
+            logratios_chunk = log_prob_chunk
+        multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
+        if kl is not None:
+            losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
+        else:
+            losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)
+
+        rewards = beta * logratios_chunk
+        chosen_rewards_sum = (rewards * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_rewards_sum = (rewards * (~preference_labels_chunk).unsqueeze(1)).sum()
+
+        return losses.sum() / (full_target.shape[0]), chosen_rewards_sum, rejected_rewards_sum
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        _input,
+        weight,
+        target,
+        preference_labels,
+        bias=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        kl=None,
+        ignore_index=-100,
+        beta=0.1,
+        compiled=True,
+        use_ref_model=True,
+        average_log_prob=False,
+        chunk_size=1,
+    ):
+        """
+        Fused linear layer with KTO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            preference_labels (torch.Tensor): Preference labels tensor. Shape: (batch_size,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            kl (torch.Tensor, optional): KL divergence tensor. Shape: (batch_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Temperature parameter for the KTO loss
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            preference_labels=preference_labels,
+            bias=bias,
+            ignore_index=ignore_index,
+            beta=beta,
+            compiled=compiled,
+            use_ref_model=use_ref_model,
+            ref_input=ref_input,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
+            kl=kl,
+            chunk_size=chunk_size,
+        )
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grads = LigerFusedLinearUnpairedPreferenceBase.backward(ctx, grad_output)[:5]
+        # One gradient per forward input: 5 computed, 10 None placeholders for the
+        # non-differentiable arguments (ref tensors, kl, and hyperparameters)
+        return (
+            *grads,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+class LigerFusedLinearKTOLoss(torch.nn.Module):
+    """
+    Fused linear layer with Kahneman-Tversky Optimization (KTO) loss.
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        compiled: bool = True,
+        use_ref_model: bool = False,
+        average_log_prob: bool = False,
+        chunk_size: int = 1,
+    ):
+        """
+        Args:
+            ignore_index (int): Index to ignore in the loss calculation
+            beta (float): Temperature parameter for the KTO loss
+            compiled (bool): Whether to use compiled operations
+            use_ref_model (bool): Whether to use a reference model for the KTO loss.
+ average_log_prob (bool): Whether to average the log probability per non-masked token + chunk_size (int): Size of chunks for processing + """ + super().__init__() + self.ignore_index = ignore_index + self.beta = beta + self.compiled = compiled + self.use_ref_model = use_ref_model + self.average_log_prob = average_log_prob + self.chunk_size = chunk_size + + def forward( + self, + _input, + lin_weight, + target, + bias=None, + preference_labels=None, + ref_input=None, + ref_weight=None, + ref_bias=None, + kl=None, + ): + return LigerFusedLinearKTOFunction.apply( + _input, + lin_weight, + target, + preference_labels, + bias, + ref_input, + ref_weight, + ref_bias, + kl, + self.ignore_index, + self.beta, + self.compiled, + self.use_ref_model, + self.average_log_prob, + self.chunk_size, + ) diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..7cb0bc3716444b5f77d8ab80cbee445495bd6f59 --- /dev/null +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -0,0 +1,144 @@ +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase + + +class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase): + @staticmethod + def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1): + """ + Paper: https://arxiv.org/pdf/2403.07691 + + Formula: + Compute odds-ratio loss: L_OR = -log(σ(log(odds_θ(y_w|x) / odds_θ(y_l|x)))) + where odds_θ(y|x) = P_θ(y|x) / (1 - P_θ(y|x)) + + Where: + - P_θ(y|x): Policy (model) probability + - y_w: Chosen sequence + - y_l: Rejected sequence + - σ: Sigmoid function + - β: Weight for the odds ratio loss + - odds_θ: Odds function for the policy + + Args: + chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). + rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). + full_target (torch.Tensor): Non chunked full target tensor + beta (float): Weight for the odds ratio loss. + """ + log_odds = (chosen_logps - rejected_logps) - ( + torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps)) + ) + ratio = F.logsigmoid(log_odds) + loss = -beta * ratio.sum() / (full_target.shape[0] // 2) + + chosen_rewards = beta * chosen_logps + rejected_rewards = beta * rejected_logps + + log_odds_ratio = torch.sum(ratio) / (full_target.shape[0] // 2) + log_odds_chosen = torch.sum(log_odds) / (full_target.shape[0] // 2) + + return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen + + @classmethod + def forward( + cls, + ctx, + _input, + weight, + target, + bias=None, + ignore_index=-100, + beta=0.1, + compute_nll_loss=True, + nll_target=None, + compiled=True, + chunk_size=1, + ): + """ + Fused linear layer with ORPO loss. + Args: + _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size) + weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size) + target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,) + bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,) + ignore_index (int): Index to ignore in loss computation + beta (float): Weight for the odds ratio loss + compute_nll_loss (bool): Whether to compute the NLL loss + nll_target (torch.LongTensor, optional): Target tensor for NLL loss. 
Shape: (batch_size * seq_len,)
+            compiled (bool): Whether to use torch compile
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ignore_index=ignore_index,
+            beta=beta,
+            compute_nll_loss=compute_nll_loss,
+            nll_target=nll_target,
+            compiled=compiled,
+            chunk_size=chunk_size,
+        )
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        return *grads, None, None, None, None, None, None
+
+
+class LigerFusedLinearORPOLoss(torch.nn.Module):
+    """
+    Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        compute_nll_loss: bool = True,
+        compiled: bool = True,
+        chunk_size: int = 1,
+    ):
+        """
+        Args:
+            ignore_index (int): Index to ignore in the loss.
+            beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            chunk_size (int): Size of chunks for processing.
+        """
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.beta = beta
+        self.compute_nll_loss = compute_nll_loss
+        self.compiled = compiled
+        self.chunk_size = chunk_size
+
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+        nll_target=None,
+    ):
+        return LigerFusedLinearORPOFunction.apply(
+            _input,
+            lin_weight,
+            target,
+            bias,
+            self.ignore_index,
+            self.beta,
+            self.compute_nll_loss,
+            nll_target,
+            self.compiled,
+            self.chunk_size,
+        )
diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py
new file mode 100755
index 0000000000000000000000000000000000000000..e24bb29b6ea82dbdb895020e5c953391cf9f86cf
--- /dev/null
+++ b/src/liger_kernel/chunked_loss/simpo_loss.py
@@ -0,0 +1,165 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
+
+
+class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
+    @staticmethod
+    def preference_loss_fn(
+        chosen_logps,
+        rejected_logps,
+        full_target,
+        beta=0.1,
+        gamma=0.5,
+        label_smoothing=0.0,
+    ):
+        """
+        Paper: https://arxiv.org/pdf/2405.14734
+
+        Formula:
+        L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)]
+
+        Where:
+        - π_θ(y|x): Policy (model) probability
+        - y_w: Chosen sequence
+        - y_l: Rejected sequence
+        - |y_w|, |y_l|: Sequence lengths
+        - σ: Sigmoid function
+        - β: beta weight
+        - γ: gamma, the target reward margin
+
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            full_target: Non-chunked full target tensor
+            beta (float): beta weight
+            gamma (float): gamma, the target reward margin
+            label_smoothing (float): Label smoothing factor; the loss reduces to the equation above as label_smoothing -> 0.
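+
+            With label smoothing eps, the implementation below computes the chunk loss as
+                -(1 - eps) * logsigmoid(logits) - eps * logsigmoid(-logits),
+            where logits = beta * (chosen_logps - rejected_logps) - gamma.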
+ """ + logits = beta * (chosen_logps - rejected_logps) - gamma + loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / ( + full_target.shape[0] // 2 + ) + + chosen_rewards = beta * chosen_logps + rejected_rewards = beta * rejected_logps + + return loss, chosen_rewards, rejected_rewards + + @classmethod + def forward( + cls, + ctx, + _input, + weight, + target, + bias=None, + ignore_index=-100, + beta=0.1, + alpha=1.0, + label_smoothing=0.0, + compute_nll_loss=False, + compiled=True, + gamma=0.5, + chunk_size=1, + ): + """ + Fused linear layer with SimPO loss. + Args: + _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size) + weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size) + target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,) + bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,) + ignore_index (int): Index to ignore in loss computation + beta (float): Weight for the odds ratio loss + alpha (float): Weight for the alpha parameter + label_smoothing (float): Label smoothing factor + compute_nll_loss (bool): Whether to compute the NLL loss + compiled (bool): Whether to use torch compile + gamma (float): Weight for the gamma parameter + chunk_size (int): Size of chunks for processing + Returns: + torch.Tensor: Computed loss + """ + return super().forward( + cls=cls, + ctx=ctx, + _input=_input, + weight=weight, + target=target, + bias=bias, + ignore_index=ignore_index, + alpha=alpha, + beta=beta, + label_smoothing=label_smoothing, + compute_nll_loss=compute_nll_loss, + compiled=compiled, + gamma=gamma, + chunk_size=chunk_size, + ) + + @staticmethod + def backward(ctx, *grad_output): + grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] + return *grads, None, None, None, None, None, None, None, None + + +class LigerFusedLinearSimPOLoss(torch.nn.Module): + """ + Fused linear layer with SimPO loss. + """ + + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + alpha: float = 1.0, + label_smoothing: float = 0.0, + compute_nll_loss: bool = True, + compiled: bool = True, + gamma: float = 0.5, + chunk_size: int = 1, + ): + """ + Args: + ignore_index (int): Index to ignore in the loss. + beta (float): Weight for the odds ratio loss. + alpha (float): Weight for the alpha parameter. + label_smoothing (float): Label smoothing factor. + compute_nll_loss (bool): Whether to compute the NLL loss. + compiled (bool): Whether to use the torch compiled kernel. + gamma (float): Weight for the gamma parameter. + chunk_size (int): Size of chunks for processing. 
+ """ + super().__init__() + self.ignore_index = ignore_index + self.beta = beta + self.alpha = alpha + self.label_smoothing = label_smoothing + self.compute_nll_loss = compute_nll_loss + self.compiled = compiled + self.gamma = gamma + self.chunk_size = chunk_size + + def forward( + self, + lin_weight, + _input, + target, + bias=None, + ): + return LigerFusedLinearSimPOFunction.apply( + _input, + lin_weight, + target, + bias, + self.ignore_index, + self.beta, + self.alpha, + self.label_smoothing, + self.compute_nll_loss, + self.compiled, + self.gamma, + self.chunk_size, + ) diff --git a/src/liger_kernel/env_report.py b/src/liger_kernel/env_report.py new file mode 100755 index 0000000000000000000000000000000000000000..ff31855090ffffa54dd42a6d70ab333739cc2c55 --- /dev/null +++ b/src/liger_kernel/env_report.py @@ -0,0 +1,63 @@ +import platform +import sys + +from importlib.metadata import version + + +def print_env_report(): + """ + + Prints a report of the environment. Useful for debugging and reproducibility. + Usage: + ``` + python -m liger_kernel.env_report + ``` + + """ + print("Environment Report:") + print("-------------------") + print(f"Operating System: {platform.platform()}") + print(f"Python version: {sys.version.split()[0]}") + + try: + print(f"Liger Kernel version: {version('liger-kernel')}") + except ImportError: + print("Liger Kernel: Not installed") + + try: + import torch + + print(f"PyTorch version: {torch.__version__}") + cuda_version = torch.version.cuda if torch.cuda.is_available() else "Not available" + print(f"CUDA version: {cuda_version}") + hip_version = torch.version.hip if torch.cuda.is_available() and torch.version.hip else "Not available" + print(f"HIP(ROCm) version: {hip_version}") + + except ImportError: + print("PyTorch: Not installed") + print("CUDA version: Unable to query") + print("HIP(ROCm) version: Unable to query") + + try: + import triton + + print(f"Triton version: {triton.__version__}") + except ImportError: + print("Triton: Not installed") + + try: + import transformers + + print(f"Transformers version: {transformers.__version__}") + except ImportError: + print("Transformers: Not installed") + + try: + xpu_version = torch.version.xpu if torch.xpu.is_available() else "XPU Not Available" + print(f"XPU version: {xpu_version}") + except ImportError: + print("XPU version: Unable to query") + + +if __name__ == "__main__": + print_env_report() diff --git a/src/liger_kernel/ops/__init__.py b/src/liger_kernel/ops/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..cc7d0b038ba8e049f29678580c9afc342ba89411 --- /dev/null +++ b/src/liger_kernel/ops/__init__.py @@ -0,0 +1,144 @@ +""" +Liger-Kernel operators with automatic vendor-specific replacement. + +This module provides two ways to import operators: + +1. Import from this package (recommended for Function classes): + from liger_kernel.ops import LigerGELUMulFunction + + This automatically uses vendor-specific implementation if available. + +2. Import from submodules (for kernel functions or specific access): + from liger_kernel.ops.geglu import geglu_forward, geglu_backward + + This always uses the default implementation (no auto-replacement). + +The replacement mechanism: +1. Default implementations are imported from individual modules (e.g., geglu.py) +2. On module load, device is detected via infer_device() +3. If running on a supported vendor device (npu, xpu, etc.), the default + implementations are replaced with vendor-specific ones +4. 
All subsequent imports from this package get the replaced versions + +Note: Direct imports from submodules (e.g., from liger_kernel.ops.geglu import ...) + are NOT affected by the replacement mechanism. +""" + +# ============================================================================= +# Import default implementations +# Both Function classes and kernel functions are imported here. +# All of these can be replaced by vendor-specific implementations. +# ============================================================================= + +from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction # noqa: F401 +from liger_kernel.ops.cross_entropy import cross_entropy_backward # noqa: F401 +from liger_kernel.ops.cross_entropy import cross_entropy_forward # noqa: F401 +from liger_kernel.ops.dyt import LigerDyTFunction # noqa: F401 +from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction # noqa: F401 +from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction # noqa: F401 +from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_backward # noqa: F401 +from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_forward # noqa: F401 +from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction # noqa: F401 +from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_backward # noqa: F401 +from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_forward # noqa: F401 +from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction # noqa: F401 +from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_backward # noqa: F401 +from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_forward # noqa: F401 +from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction # noqa: F401 +from liger_kernel.ops.geglu import LigerGELUMulFunction # noqa: F401 +from liger_kernel.ops.geglu import geglu_backward # noqa: F401 +from liger_kernel.ops.geglu import geglu_forward # noqa: F401 +from liger_kernel.ops.group_norm import LigerGroupNormFunction # noqa: F401 +from liger_kernel.ops.group_norm import group_norm_backward # noqa: F401 +from liger_kernel.ops.group_norm import group_norm_forward # noqa: F401 +from liger_kernel.ops.grpo_loss import GrpoLossFunction # noqa: F401 +from liger_kernel.ops.jsd import LigerJSDFunction # noqa: F401 +from liger_kernel.ops.jsd import jsd_backward # noqa: F401 +from liger_kernel.ops.jsd import jsd_forward # noqa: F401 +from liger_kernel.ops.kl_div import LigerKLDivLossFunction # noqa: F401 +from liger_kernel.ops.layer_norm import LigerLayerNormFunction # noqa: F401 +from liger_kernel.ops.layer_norm import layer_norm_backward # noqa: F401 +from liger_kernel.ops.layer_norm import layer_norm_forward # noqa: F401 +from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction # noqa: F401 +from liger_kernel.ops.mhc import LigerMHCCoeffsFunction # noqa: F401 +from liger_kernel.ops.mhc import LigerMHCPostResFunction # noqa: F401 +from liger_kernel.ops.mhc import LigerMHCPreFunction # noqa: F401 +from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction # noqa: F401 +from liger_kernel.ops.poly_norm import LigerPolyNormFunction # noqa: F401 +from liger_kernel.ops.poly_norm import poly_norm_backward # noqa: F401 +from liger_kernel.ops.poly_norm import poly_norm_forward # noqa: F401 +from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction # noqa: F401 
+from liger_kernel.ops.rms_norm import LigerRMSNormFunction  # noqa: F401
+from liger_kernel.ops.rms_norm import rms_norm_backward  # noqa: F401
+from liger_kernel.ops.rms_norm import rms_norm_forward  # noqa: F401
+from liger_kernel.ops.rope import LigerRopeFunction  # noqa: F401
+from liger_kernel.ops.rope import rope_backward  # noqa: F401
+from liger_kernel.ops.rope import rope_forward  # noqa: F401
+from liger_kernel.ops.softmax import LigerSoftmaxFunction  # noqa: F401
+from liger_kernel.ops.sparsemax import LigerSparsemaxFunction  # noqa: F401
+from liger_kernel.ops.swiglu import LigerSiLUMulFunction  # noqa: F401
+from liger_kernel.ops.swiglu import swiglu_backward  # noqa: F401
+from liger_kernel.ops.swiglu import swiglu_forward  # noqa: F401
+from liger_kernel.ops.tiled_mlp import LigerTiledMLPFunction  # noqa: F401
+from liger_kernel.ops.tiled_mlp import apply_tiled_mlp  # noqa: F401
+from liger_kernel.ops.tvd import LigerTVDLossFunction  # noqa: F401
+
+# NOTE: __all__ is intentionally NOT defined.
+# - Import from this package (liger_kernel.ops) -> subject to vendor replacement
+# - Import from submodules (liger_kernel.ops.geglu) -> always use default implementation
+
+
+# =============================================================================
+# Vendor-specific replacement logic
+# =============================================================================
+
+
+def _replace_with_vendor_ops():
+    """
+    Replace/add vendor-specific operator implementations.
+
+    This function is called automatically on module load. It:
+    1. Detects the current device (cuda, npu, xpu, etc.)
+    2. Looks up the vendor for that device via VENDOR_REGISTRY
+    3. Loads and applies vendor-specific implementations
+
+    Vendor implementations should be placed in:
+        liger_kernel/ops/backends/_<vendor>/ops/
+
+    If the vendor module defines __all__, only those symbols are exported.
+    Otherwise, all public symbols (not starting with _) are auto-discovered.
+
+    Note: Vendors can both override existing ops AND add new vendor-specific ops.
+    """
+    from liger_kernel.ops.backends import get_vendor_for_device
+    from liger_kernel.utils import infer_device
+
+    device = infer_device()
+
+    # Look up vendor info for this device
+    vendor_info = get_vendor_for_device(device)
+    if vendor_info is None:
+        return
+
+    try:
+        import importlib
+
+        vendor_ops = importlib.import_module(vendor_info.module_path)
+
+        # Get names to export: use __all__ if defined, otherwise auto-discover
+        names_to_export = getattr(vendor_ops, "__all__", None)
+
+        if names_to_export is None:
+            # Auto-discover: find all public symbols (classes and functions)
+            names_to_export = [name for name in dir(vendor_ops) if not name.startswith("_")]
+
+        # Replace or add to this module's globals
+        for name in names_to_export:
+            globals()[name] = getattr(vendor_ops, name)
+
+    except ImportError:
+        # Vendor module not available, use default implementations
+        pass
+
+
+_replace_with_vendor_ops()
diff --git a/src/liger_kernel/ops/backends/README.md b/src/liger_kernel/ops/backends/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..d4067157b60be225fc361acca9dea025d63d757c
--- /dev/null
+++ b/src/liger_kernel/ops/backends/README.md
@@ -0,0 +1,151 @@
+# Adding a New Vendor Backend
+
+This directory contains vendor-specific operator implementations that automatically replace the default (CUDA) implementations when running on the corresponding device.
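+
+For example (a sketch of the resulting behavior; the alias name below is arbitrary):
+
+```python
+# Resolves to the vendor override when one is registered for the detected device:
+from liger_kernel.ops import LigerGELUMulFunction
+
+# Always the default implementation, regardless of device:
+from liger_kernel.ops.geglu import LigerGELUMulFunction as DefaultGELUMul
+```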
+
+## Concepts
+
+- **Vendor**: Chip manufacturer (e.g., `ascend`, `intel`, `nvidia`)
+- **Device**: Device type (e.g., `npu`, `xpu`, `cuda`)
+- **VendorInfo**: Defines the mapping between vendor and device
+
+## Directory Structure
+
+```
+backends/
+├── README.md
+├── __init__.py
+├── registry.py          # VendorInfo, register_vendor(), VENDOR_REGISTRY
+├── _ascend/             # Ascend (Huawei) vendor - supports NPU
+│   ├── __init__.py      # Registers VendorInfo for NPU
+│   └── ops/
+│       ├── __init__.py  # Exports vendor-specific implementations
+│       └── geglu.py     # NPU-specific GEGLU implementation
+└── _<vendor>/           # Your new vendor backend
+    └── ...
+```
+
+## How It Works
+
+1. When `liger_kernel.ops.backends` is imported, it imports all vendor packages (e.g., `_ascend`)
+2. Each vendor's `__init__.py` calls `register_vendor()` to register itself
+3. When `liger_kernel.ops` is imported, `_replace_with_vendor_ops()` is called
+4. It detects the current device via `infer_device()` and looks up the vendor
+5. Vendor implementations replace/add to the `liger_kernel.ops` namespace
+
+## Adding a New Vendor
+
+### Step 1: Create Directory Structure
+
+```bash
+mkdir -p backends/_<vendor>/ops
+touch backends/_<vendor>/__init__.py
+touch backends/_<vendor>/ops/__init__.py
+```
+
+### Step 2: Register Your Vendor
+
+In `backends/_<vendor>/__init__.py`, register your vendor:
+
+```python
+"""
+<Vendor> backend for Liger-Kernel.
+"""
+
+from liger_kernel.ops.backends.registry import VendorInfo, register_vendor
+
+register_vendor(
+    VendorInfo(
+        vendor="<vendor>",
+        device="<device>",
+    )
+)
+```
+
+### Step 3: Ensure Device Detection Works
+
+Make sure `infer_device()` in `liger_kernel/utils.py` can detect your device:
+
+```python
+def infer_device():
+    if torch.cuda.is_available():
+        return "cuda"
+    if is_npu_available():
+        return "npu"
+    # Add your device detection here
+    if is_<device>_available():
+        return "<device>"
+    return "cpu"
+```
+
+### Step 4: Implement Vendor-Specific Operators
+
+Create operator files in `backends/_<vendor>/ops/`. For example, `geglu.py`:
+
+```python
+import torch
+
+class LigerGELUMulFunction(torch.autograd.Function):
+    """
+    Vendor-specific LigerGELUMulFunction implementation.
+    """
+    @staticmethod
+    def forward(ctx, a, b):
+        # Your vendor-specific forward implementation
+        ...
+
+    @staticmethod
+    def backward(ctx, dc):
+        # Your vendor-specific backward implementation
+        ...
+
+# Optional: vendor-specific kernel functions
+def geglu_forward_vendor(a, b):
+    ...
+
+def geglu_backward_vendor(a, b, dc):
+    ...
+```
+
+### Step 5: Export in `ops/__init__.py`
+
+In `backends/_<vendor>/ops/__init__.py`, export your implementations:
+
+```python
+"""
+<Vendor>-specific operator implementations.
+"""
+
+from .geglu import (
+    LigerGELUMulFunction,
+    geglu_forward_vendor as geglu_forward,  # Rename to match default API
+    geglu_backward_vendor as geglu_backward,
+)
+
+# Explicitly declare what to export (recommended)
+__all__ = [
+    "LigerGELUMulFunction",
+    "geglu_forward",
+    "geglu_backward",
+]
+```
+
+## Key Points
+
+### Incremental Override
+
+You **don't need to implement all operators**. Only implement the ones that require vendor-specific adaptations. Unimplemented operators automatically fall back to the default (CUDA) implementation.
+
+### Vendor-Specific Additions
+
+Vendors can also **add new operators** that don't exist in the default implementation. These are exported to the `liger_kernel.ops` namespace for users to import.
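+
+For example, a hypothetical vendor-only op registered this way becomes importable as:
+
+```python
+# Hypothetical name for illustration; not an op that ships with Liger-Kernel
+from liger_kernel.ops import my_vendor_fused_op
+```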
+ +### Naming Convention + +- Use the **same class/function names** as the default implementations for overrides +- This allows seamless replacement without changing user code +- Use `as` imports to rename if your internal naming differs + +## Example: Ascend NPU Backend + +See `_ascend/` directory for a complete example of the Ascend NPU backend implementation. diff --git a/src/liger_kernel/ops/backends/__init__.py b/src/liger_kernel/ops/backends/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..ad7779c48afdc2efb3783ad80a87224c407e1c28 --- /dev/null +++ b/src/liger_kernel/ops/backends/__init__.py @@ -0,0 +1,13 @@ +import importlib +import pkgutil + +from liger_kernel.ops.backends.registry import VENDOR_REGISTRY # noqa: F401 +from liger_kernel.ops.backends.registry import VendorInfo # noqa: F401 +from liger_kernel.ops.backends.registry import get_vendor_for_device # noqa: F401 +from liger_kernel.ops.backends.registry import register_vendor # noqa: F401 + +# Auto-import all _ subpackages to trigger registration +# Each vendor's __init__.py calls register_vendor() when imported +for _, modname, ispkg in pkgutil.iter_modules(__path__): + if ispkg and modname.startswith("_"): + importlib.import_module(f"{__name__}.{modname}") diff --git a/src/liger_kernel/ops/backends/_ascend/__init__.py b/src/liger_kernel/ops/backends/_ascend/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..a07e7ab09e89f8781949bb4d6d6b1fb1e27e3344 --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/__init__.py @@ -0,0 +1,5 @@ +from liger_kernel.ops.backends.registry import VendorInfo +from liger_kernel.ops.backends.registry import register_vendor + +# Register Ascend vendor for NPU device +register_vendor(VendorInfo(vendor="ascend", device="npu")) diff --git a/src/liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md b/src/liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md new file mode 100755 index 0000000000000000000000000000000000000000..bf9faa3dc09ce5b3c57aa1a7827f4faa0e1fab61 --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md @@ -0,0 +1,492 @@ +# Ascend NPU UB Manager Design Document + +## Overview + +The UB Manager (Unified Buffer Manager) is a core component in **Liger-Kernel** responsible for managing the Unified Buffer (UB) capacity on Ascend NPUs. By automatically detecting UB capacity and providing unified tiling strategy computation, it helps Triton kernels avoid UB overflow errors while maintaining high performance. + +## Design Goals + +1. **Automated UB Management**: Automatically detect device UB capacity without manual configuration +2. **Unified Strategy System**: Use a single unified strategy function for all kernels, abstracting memory calculations +3. **Flexible Parameters**: Support different memory multipliers and safety margins for different kernels +4. 
**Easy to Use**: Simple interface that directly computes tiling results + +## Architecture Design + +### Core Components + +``` +┌─────────────────────────────────────────────────────────┐ +│ UB Manager System │ +├─────────────────────────────────────────────────────────┤ +│ │ +│ ┌──────────────┐ ┌──────────────────┐ │ +│ │ UBManager │ │ Default Strategy │ │ +│ │ (Singleton)│────────▶│ Function │ │ +│ └──────────────┘ └──────────────────┘ │ +│ │ │ │ +│ │ │ │ +│ ▼ ▼ │ +│ ┌──────────────┐ ┌──────────────────┐ │ +│ │ Capacity │ │ compute_default │ │ +│ │ Detection │ │ _tiling_strategy│ │ +│ └──────────────┘ └──────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────┘ + │ │ + │ │ + ▼ ▼ +┌──────────────┐ ┌──────────────────┐ +│ GEGLU │ │ ROPE │ +│ Kernel │ │ Kernel │ +└──────────────┘ └──────────────────┘ +``` + +### Class Diagram + +``` +┌──────────────────────────────────────┐ +│ UBManager │ +├──────────────────────────────────────┤ +│ - _npu_model: str │ +│ - _ub_capacity_bits: int │ +├──────────────────────────────────────┤ +│ + ub_capacity_bits: int │ +│ + ub_capacity_bytes: int │ +│ + npu_model: str │ +│ - _detect_npu_model() │ +│ - _detect_ub_capacity() │ +│ (raises RuntimeError if fails) │ +└──────────────────────────────────────┘ + +┌──────────────────────────────────────┐ +│ compute_default_tiling_strategy │ +├──────────────────────────────────────┤ +│ + safety_margin: float │ +│ + dtype_size: int │ +│ + memory_multiplier: float │ +│ + shapes: Tuple[Tuple[int, ...], ...]│ +│ + tiling_dims: Tuple │ +├──────────────────────────────────────┤ +│ Returns: Tuple[Tuple[int, ...], ...] │ +│ (same structure as shapes) │ +└──────────────────────────────────────┘ + +┌──────────────────────────────────────┐ +│ _normalize_tiling_dims │ +├──────────────────────────────────────┤ +│ Helper function to normalize │ +│ tiling_dim (int or tuple) to set │ +└──────────────────────────────────────┘ +``` + +## Core Functionality + +### 1. UB Capacity Detection + +The UB Manager detects UB capacity in the following priority order: + +1. **Environment Variable**: `ASCEND_UB_CAPACITY_BITS` (in bits) + - If set, this value is used directly + - Must be a positive integer representing UB capacity in bits + +2. **get_soc_spec**: Query UB size from CANN's `get_soc_spec("UB_SIZE")` + - Returns UB size in bytes + - Automatically converted to bits (bytes * 8) + - Requires CANN environment to be sourced (e.g., `source /usr/local/Ascend/ascend-toolkit/set_env.sh`) + +3. **Error Handling**: If neither method succeeds, raises `RuntimeError` with clear instructions + + +```python +# Detection flow: +# 1. Check ASCEND_UB_CAPACITY_BITS env var (bits) +# 2. Try get_soc_spec("UB_SIZE") (bytes) -> convert to bits +# 3. Raise RuntimeError if both fail +``` + +### 2. Unified Strategy System + +All kernels use a single unified strategy function `_default_strategy` that abstracts memory calculations: + +``` +Memory Formula: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits +``` + +Where `unit_param` is automatically calculated as the product of all fixed (non-tiling) dimensions in each shape. 
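+
+As a worked example (a sketch assuming a 192KB UB, the Atlas 800I A2 figure used in the Ascend kernels, plus the default 0.80 safety margin and the ROPE-style parameters described below):
+
+```python
+# Solve the memory formula for the largest safe BLOCK_SIZE.
+ub_capacity_bits = 192 * 1024 * 8          # 1,572,864 bits (192KB UB)
+safe_bits = ub_capacity_bits * 0.80        # 1,258,291.2 bits after the safety margin
+bits_per_row = 3.0 * 128 * 4 * 8           # memory_multiplier * unit_param (pad_hd=128) * dtype_size (fp32) * 8
+max_block_size = safe_bits / bits_per_row  # ~102.4
+# Largest power of 2 <= 102.4 is 64, so BLOCK_SIZE = 64.
+```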
+ +The strategy function: +- Takes UB capacity, safety margin, dtype size, memory multiplier, shapes, and tiling dimension specifications +- For each shape, identifies which dimensions can be tiled (from `tiling_dims`) +- Calculates `unit_param` as the product of fixed (non-tiling) dimensions +- Calculates the maximum safe block size that fits within UB capacity +- Returns a tuple of max_safe_block_size values (one for each shape) + +The `compute_default_tiling_strategy` function: +- Calls `_default_strategy` to get max_safe_block_size for each shape +- For each tiling dimension, computes desired block size using `triton.next_power_of_2(original_dim)` +- Returns the final result with same structure as input shapes: tiling dimensions replaced with computed block sizes, non-tiling dimensions padded to next power of 2 + +### 3. Parameter Structure + +The unified strategy uses the following parameters: + +- **`safety_margin`**: Safety margin as a float (e.g., 0.80 for 80%). Default is 0.80. +- **`dtype_size`**: Size of data type in bytes (e.g., 2 for float16, 4 for float32) +- **`memory_multiplier`**: Memory multiplier for estimating peak memory usage + - For GEGLU: typically 10.0 for backward, 7.0 for forward + - For ROPE: typically 3.0 +- **`shapes`**: Tuple of full shapes. Each shape is a tuple of dimension sizes. + - For ROPE: `((n_q_head, hd), (n_kv_head, hd))` + - For GEGLU: `((n_cols,),)` + - Can pass original shapes (will handle padding internally) or padded shapes +- **`tiling_dims`**: Tuple specifying which dimensions can be tiled for each shape. + - Each element can be: + - `int`: single dimension index (e.g., `0` for first dimension) + - `tuple of ints`: multiple dimensions that can be tiled together (non-empty) + - For ROPE: `(0, 0)` means first dimension of each shape can be tiled + - For GEGLU: `(0,)` means first dimension of the shape can be tiled + - Length must match `len(shapes)` + - Fixed dimensions (non-tiling) are automatically extracted from shapes and multiplied to get `unit_param` + - **Validation**: Raises `ValueError` if: + - Any `tiling_dim` is empty or invalid (e.g., empty tuple) + - Any dimension index is out of bounds (negative or >= shape length) + +### 4. 
Strategy Computation Flow + +``` +User calls compute_default_tiling_strategy() + │ + ▼ +Get UB manager instance + │ + ▼ +Validate shapes and tiling_dims (lengths must match) + │ + ▼ +Set defaults for dtype_size (4) and memory_multiplier (10.0) + │ + ▼ +Call _default_strategy() with: + - ub_capacity_bits + - safety_margin + - dtype_size + - memory_multiplier + - shapes + - tiling_dims + │ + ▼ +For each (shape, tiling_dim) pair: + Normalize tiling_dim to set of dimension indices + Validate tiling dimensions are within shape bounds + (Raises ValueError if invalid) + │ + ▼ + Calculate unit_param: + unit_param = product of all non-tiling dimensions + │ + ▼ + Calculate max_block_size: + SAFE_UB_CAPACITY_BITS = ub_capacity_bits * safety_margin + max_block_size = SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8) + │ + ▼ + Find largest power of 2 <= max_block_size + │ + ▼ +Return tuple of max_safe_block_size (one per shape) + │ + ▼ +Build result with same structure as shapes: + For each (shape, tiling_dim, max_safe): + For each tiling dimension: + desired = triton.next_power_of_2(original_dim) + final = min(desired, max_safe) + final = max(1, final) + For each non-tiling dimension: + pad to triton.next_power_of_2(original_dim) + │ + ▼ +Return tuple of tiled shapes +``` + +## Usage Examples + +### Basic Usage + +```python +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy + +# GEGLU forward +shapes = ((4096,),) +tile_shapes = compute_default_tiling_strategy( + safety_margin=0.80, + dtype_size=2, # float16 + memory_multiplier=7.0, + shapes=shapes, + tiling_dims=(0,) # First dimension can be tiled +) +if tile_shapes is not None and len(tile_shapes) > 0: + block_size = tile_shapes[0][0] + # Call kernel with block_size + +# ROPE forward +shapes = ((32, 128), (32, 128)) # (n_q_head, hd), (n_kv_head, hd) +tile_shapes = compute_default_tiling_strategy( + safety_margin=0.90, + dtype_size=4, # float32 + memory_multiplier=3.0, + shapes=shapes, + tiling_dims=(0, 0) # First dimension of each shape can be tiled +) +if tile_shapes is not None and len(tile_shapes) == len(shapes): + q_tile_shape, k_tile_shape = tile_shapes + BLOCK_Q, _ = q_tile_shape # Tiled dimension + BLOCK_K, _ = k_tile_shape # Tiled dimension + # Call kernel with BLOCK_Q and BLOCK_K +``` + +## Strategy Function Details + +### `_normalize_tiling_dims` Helper Function + +A helper function that normalizes tiling dimension specifications: + +```python +def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set: + """ + Normalize tiling dimension specification to a set of dimension indices. + + Args: + tiling_dim: Either an int (single dimension) or tuple of ints (multiple dimensions). + + Returns: + Set of dimension indices that can be tiled. + """ +``` + +This function handles the conversion of `tiling_dim` from either an `int` or `tuple` to a `set` for consistent processing. + +### `_default_strategy` Function + +The core strategy function that calculates maximum safe block size: + +```python +def _default_strategy( + ub_capacity_bits: int, + safety_margin: float, + dtype_size: int, + memory_multiplier: float, + shapes: Tuple[Tuple[int, ...], ...], + tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...], +) -> Tuple[int, ...]: + """ + Calculate maximum safe block size based on UB capacity. + + Memory formula: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits + + For each shape, fixed dimensions (non-tiling) are multiplied together to get unit_param. 
+ + Returns: + Tuple of max_safe_block_size (power of 2), one for each shape. + + Raises: + ValueError: If any tiling_dim is empty or invalid, or if any dimension + index is out of bounds for the corresponding shape. + """ +``` + +**Key Steps:** +1. For each `(shape, tiling_dim)` pair: + - Normalize `tiling_dim` to a set of dimension indices using `_normalize_tiling_dims` + - Validate tiling dimensions are within shape bounds + - Raises `ValueError` if `tiling_dim` is empty or invalid + - Raises `ValueError` if any dimension index is out of bounds + - Calculate `unit_param` as the product of all non-tiling dimensions + - If all dimensions are tiling, `unit_param = 1.0` +2. Calculate `SAFE_UB_CAPACITY_BITS = ub_capacity_bits * safety_margin` +3. Solve for max_block_size: `SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8)` +4. Find largest power of 2 <= max_block_size +5. Return tuple with one max_safe_block_size per shape + +### `compute_default_tiling_strategy` Function + +The public interface that computes final tiling results: + +```python +def compute_default_tiling_strategy( + safety_margin: float = 0.80, + dtype_size: Optional[int] = None, + memory_multiplier: Optional[float] = None, + shapes: Optional[Tuple[Tuple[int, ...], ...]] = None, + tiling_dims: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None, +) -> Optional[Tuple[Tuple[int, ...], ...]]: + """ + Compute tiling strategy using the default strategy function. + + Returns tuple of tiled shapes with same structure as input shapes. + Tiling dimensions are replaced with computed block sizes (power of 2), + while non-tiling dimensions are padded to next power of 2. + + Returns: + Tuple of tiled shapes, or None if shapes/tiling_dims are empty or + lengths don't match. + + Raises: + ValueError: If any tiling_dim is empty or invalid, or if any dimension + index is out of bounds for the corresponding shape. + """ +``` + +**Key Steps:** +1. Get UB manager instance +2. Validate `shapes` and `tiling_dims` (lengths must match, cannot be empty) + - Returns `None` if validation fails (empty or mismatched lengths) +3. Set defaults for `dtype_size` (4) and `memory_multiplier` (10.0) if not provided +4. Call `_default_strategy` to get `max_supported` (tuple of max_safe_block_size, one per shape) + - May raise `ValueError` if `tiling_dims` are invalid (see `_default_strategy` documentation) +5. For each `(shape, tiling_dim, max_safe)`: + - Normalize `tiling_dim` to a set of dimension indices + - Validate tiling dimensions are within shape bounds + - Raises `ValueError` if `tiling_dim` is empty or invalid + - Raises `ValueError` if any dimension index is out of bounds + - For each tiling dimension: + - Compute `desired = triton.next_power_of_2(original_dim)` + - Compute `final = min(desired, max_safe)` + - Ensure `final >= 1` + - Replace dimension with `final` + - For each non-tiling dimension: + - Pad to `triton.next_power_of_2(original_dim)` +6. 
Return tuple of tiled shapes (same structure as input `shapes`) + +## Memory Analysis Examples + +### GEGLU Forward + +``` +Memory analysis: +- Inputs: a, b +- Intermediates: a_cubed, tanh_arg, tanh_result, geglu_a +- Output: c +- Total: ~7x * BLOCK_SIZE * dtype_size + +Strategy: +- shapes: ((n_cols,),) +- tiling_dims: (0,) # First dimension can be tiled +- Fixed dimensions: none (all dimensions are tiling) +- unit_param = 1 (product of fixed dimensions) +- memory_multiplier = 7.0 +- Formula: 7.0 * BLOCK_SIZE * 1 * dtype_size * 8 bits +- Returns: ((block_size,),) +``` + +### GEGLU Backward + +``` +Memory analysis: +- More intermediates for gradient computation +- Total: ~10x * BLOCK_SIZE * dtype_size + +Strategy: +- shapes: ((n_cols,),) +- tiling_dims: (0,) # First dimension can be tiled +- Fixed dimensions: none (all dimensions are tiling) +- unit_param = 1 (product of fixed dimensions) +- memory_multiplier = 10.0 +- Formula: 10.0 * BLOCK_SIZE * 1 * dtype_size * 8 bits +- Returns: ((block_size,),) +``` + +### ROPE Forward/Backward + +``` +Memory analysis (based on optimized ROPE kernel): +- cos_vals and sin_vals: pad_hd // 2 elements each (shared) +- In q heads loop (peak memory): + * q_left, q_right, new_left, new_right: 2 * BLOCK_Q * pad_hd elements +- In k heads loop (peak memory): + * k_left, k_right, new_left, new_right: 2 * BLOCK_K * pad_hd elements +- Plus shared cos/sin: pad_hd elements +- Conservative estimate: 3 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits + +Strategy: +- shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd)) +- tiling_dims: (0, 0) # First dimension of each shape can be tiled +- Fixed dimensions: pad_hd (second dimension, non-tiling) +- unit_param = pad_hd (product of fixed dimensions) +- memory_multiplier = 3.0 +- Formula: 3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits +- Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd)) +``` + +## Extension Guide + +### Adding a New Kernel + +To add tiling support for a new kernel: + +1. **Analyze memory usage**: + - Identify peak memory usage in the kernel + - Determine memory multiplier (e.g., 7.0, 10.0, 3.0) + - Identify which dimensions can be tiled and which are fixed + - Fixed dimensions will be automatically extracted and multiplied to get `unit_param` + +2. **Use `compute_default_tiling_strategy`** in your kernel: + +```python +def my_kernel_forward(input): + # Prepare parameters + n_cols = input.shape[-1] + dtype_size = input.element_size() + + # Compute strategy + # Example 1: Simple case (all dimensions can be tiled) + shapes = ((n_cols,),) + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.80, + dtype_size=dtype_size, + memory_multiplier=7.0, # Based on your memory analysis + shapes=shapes, + tiling_dims=(0,) # First dimension can be tiled + ) + + if tile_shapes is not None and len(tile_shapes) > 0: + block_size = tile_shapes[0][0] + else: + block_size = triton.next_power_of_2(n_cols) # Fallback + + # Example 2: Multiple shapes with fixed dimensions + # shapes = ((M, K), (K, N)) + # tiling_dims = (0, 1) # First shape: dim 0 can be tiled, dim 1 is fixed + # # Second shape: dim 0 is fixed, dim 1 can be tiled + # Returns: ((block_M, K), (K, block_N)) + + # Call kernel + kernel[(grid_size,)]( + input, + BLOCK_SIZE=block_size, + ) +``` + +3. 
**Document memory analysis** in comments: + +```python +# My kernel tiling strategy: +# - Memory analysis: +# * Input: input +# * Intermediates: intermediate1, intermediate2 +# * Output: output +# * Total: ~7x * BLOCK_SIZE * dtype_size +# - shapes: ((n_cols,),) +# - tiling_dims: (0,) means first dimension can be tiled +# - Fixed dimensions: none (all dimensions are tiling) +# - unit_param = 1 (product of fixed dimensions) +# - Uses memory_multiplier=7.0 * BLOCK_SIZE * dtype_size * 8 bits for safety +# - compute_default_tiling_strategy returns: ((block_size,),) +# where block_size = min(triton.next_power_of_2(n_cols), max_safe_block_size) +``` + +## Future Improvements + +1. **Strategy Variants**: If needed, could add specialized strategy functions for specific kernels while keeping the unified interface +2. **Multi-dimensional Tiling**: Could extend to support more complex tiling patterns if needed diff --git a/src/liger_kernel/ops/backends/_ascend/ops/__init__.py b/src/liger_kernel/ops/backends/_ascend/ops/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..38fcc7ee1c9677373665cfff5071d13cc3dd0dcc --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/__init__.py @@ -0,0 +1,139 @@ +""" +Ascend NPU operator implementations. + +This module exports Ascend NPU-optimized implementations that will automatically +replace the default implementations when running on NPU devices. + +Both Function classes and kernel functions can be exported here. + +To add a new operator: +1. Create the implementation file (e.g., rms_norm.py) +2. Import the Function class and/or kernel functions here +3. Optionally add to __all__ for explicit control + +If __all__ is not defined, all public symbols will be auto-discovered. +""" + +from liger_kernel.ops.backends._ascend.ops.cross_entropy import LigerCrossEntropyFunction +from liger_kernel.ops.backends._ascend.ops.cross_entropy import cross_entropy_backward +from liger_kernel.ops.backends._ascend.ops.cross_entropy import cross_entropy_forward +from liger_kernel.ops.backends._ascend.ops.dyt import LigerDyTFunction +from liger_kernel.ops.backends._ascend.ops.dyt import liger_dyt_bwd +from liger_kernel.ops.backends._ascend.ops.dyt import liger_dyt_fwd +from liger_kernel.ops.backends._ascend.ops.embedding import LigerEmbeddingFunction +from liger_kernel.ops.backends._ascend.ops.embedding import embedding_backward +from liger_kernel.ops.backends._ascend.ops.embedding import embedding_forward +from liger_kernel.ops.backends._ascend.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction +from liger_kernel.ops.backends._ascend.ops.fused_add_rms_norm import fused_add_rms_norm_backward +from liger_kernel.ops.backends._ascend.ops.fused_add_rms_norm import fused_add_rms_norm_forward +from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import LigerFusedLinearJSDFunction +from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import fused_linear_jsd_backward +from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import fused_linear_jsd_forward +from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction +from liger_kernel.ops.backends._ascend.ops.geglu import geglu_backward +from liger_kernel.ops.backends._ascend.ops.geglu import geglu_forward +from liger_kernel.ops.backends._ascend.ops.group_norm import LigerGroupNormFunction +from liger_kernel.ops.backends._ascend.ops.group_norm import group_norm_backward +from liger_kernel.ops.backends._ascend.ops.group_norm import group_norm_forward +from 
liger_kernel.ops.backends._ascend.ops.grpo_loss import GrpoLossFunction +from liger_kernel.ops.backends._ascend.ops.grpo_loss import grpo_loss_backward_triton +from liger_kernel.ops.backends._ascend.ops.grpo_loss import grpo_loss_forward_triton +from liger_kernel.ops.backends._ascend.ops.jsd import LigerJSDFunction +from liger_kernel.ops.backends._ascend.ops.jsd import jsd_backward +from liger_kernel.ops.backends._ascend.ops.jsd import jsd_forward +from liger_kernel.ops.backends._ascend.ops.kl_div import LigerKLDivLossFunction +from liger_kernel.ops.backends._ascend.ops.kl_div import kldiv_backward_triton +from liger_kernel.ops.backends._ascend.ops.kl_div import kldiv_forward_triton +from liger_kernel.ops.backends._ascend.ops.layer_norm import LigerLayerNormFunction +from liger_kernel.ops.backends._ascend.ops.layer_norm import layer_norm_backward +from liger_kernel.ops.backends._ascend.ops.layer_norm import layer_norm_forward +from liger_kernel.ops.backends._ascend.ops.llama4_rope import LigerLlama4RopeFunction +from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_backward +from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_forward +from liger_kernel.ops.backends._ascend.ops.poly_norm import LigerPolyNormFunction +from liger_kernel.ops.backends._ascend.ops.poly_norm import poly_norm_backward +from liger_kernel.ops.backends._ascend.ops.poly_norm import poly_norm_forward +from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction +from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_backward +from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_forward +from liger_kernel.ops.backends._ascend.ops.rms_norm import LigerRMSNormFunction +from liger_kernel.ops.backends._ascend.ops.rms_norm import rms_norm_backward +from liger_kernel.ops.backends._ascend.ops.rms_norm import rms_norm_forward +from liger_kernel.ops.backends._ascend.ops.rope import LigerRopeFunction +from liger_kernel.ops.backends._ascend.ops.rope import rope_backward +from liger_kernel.ops.backends._ascend.ops.rope import rope_forward +from liger_kernel.ops.backends._ascend.ops.softmax import LigerSoftmaxFunction +from liger_kernel.ops.backends._ascend.ops.softmax import _softmax_backward +from liger_kernel.ops.backends._ascend.ops.softmax import _softmax_forward +from liger_kernel.ops.backends._ascend.ops.sparsemax import LigerSparsemaxFunction +from liger_kernel.ops.backends._ascend.ops.sparsemax import sparsemax_backward +from liger_kernel.ops.backends._ascend.ops.sparsemax import sparsemax_forward +from liger_kernel.ops.backends._ascend.ops.swiglu import LigerSiLUMulFunction +from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_backward +from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_forward +from liger_kernel.ops.backends._ascend.ops.tvd import LigerTVDLossFunction +from liger_kernel.ops.backends._ascend.ops.tvd import tv_distance_forward_triton +from liger_kernel.ops.backends._ascend.ops.tvd import tvd_backward_triton + +__all__ = [ + "LigerEmbeddingFunction", + "embedding_forward", + "embedding_backward", + "LigerFusedAddRMSNormFunction", + "fused_add_rms_norm_forward", + "fused_add_rms_norm_backward", + "LigerGELUMulFunction", + "geglu_forward", + "geglu_backward", + "LigerQwen2VLMRopeFunction", + "qwen2vl_mrope_forward", + "qwen2vl_mrope_backward", + "LigerRMSNormFunction", + "rms_norm_forward", + "rms_norm_backward", + "LigerRopeFunction", + "rope_forward", + 
"rope_backward", + "LigerSiLUMulFunction", + "swiglu_forward", + "swiglu_backward", + "LigerTVDLossFunction", + "tv_distance_forward_triton", + "tvd_backward_triton", + "LigerLlama4RopeFunction", + "llama4_rope_forward", + "llama4_rope_backward", + "LigerPolyNormFunction", + "poly_norm_forward", + "poly_norm_backward", + "LigerDyTFunction", + "liger_dyt_fwd", + "liger_dyt_bwd", + "LigerKLDivLossFunction", + "kldiv_forward_triton", + "kldiv_backward_triton", + "LigerLayerNormFunction", + "layer_norm_backward", + "layer_norm_forward", + "LigerSoftmaxFunction", + "_softmax_forward", + "_softmax_backward", + "LigerJSDFunction", + "jsd_forward", + "jsd_backward", + "LigerCrossEntropyFunction", + "cross_entropy_backward", + "cross_entropy_forward", + "GrpoLossFunction", + "grpo_loss_forward_triton", + "grpo_loss_backward_triton", + "LigerFusedLinearJSDFunction", + "fused_linear_jsd_forward", + "fused_linear_jsd_backward", + "LigerGroupNormFunction", + "group_norm_forward", + "group_norm_backward", + "LigerSparsemaxFunction", + "sparsemax_forward", + "sparsemax_backward", +] diff --git a/src/liger_kernel/ops/backends/_ascend/ops/cross_entropy.py b/src/liger_kernel/ops/backends/_ascend/ops/cross_entropy.py new file mode 100755 index 0000000000000000000000000000000000000000..7bc512a0339b6be4129a582512c910d5a792a0d8 --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/cross_entropy.py @@ -0,0 +1,568 @@ +from typing import Optional + +import torch +import triton +import triton.language as tl + +from triton.language.math import tanh + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import element_mul_kernel +from liger_kernel.ops.utils import get_npu_core_count + + +@triton.jit +def liger_cross_entropy_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + weight_ptr, + loss_ptr, + z_loss_ptr, + loss_stride, + token_accuracy_ptr, + token_accuracy_stride, + predicted_tokens_ptr, + predicted_tokens_stride, + n_cols, + n_rows, + n_non_ignore, + sum_non_ignore_weight, + weight_sum, + ignore_index, + lse_square_scale: tl.constexpr, + label_smoothing: tl.constexpr, + reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time + softcap, + RETURN_Z_LOSS: tl.constexpr, + RETURN_TOKEN_ACCURACY: tl.constexpr, + RETURN_PREDICTED_TOKENS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_SOFTCAPPING: tl.constexpr, + HAS_GRADIENTS: tl.constexpr, +): + """ + This kernel computes both cross entropy loss and the gradient of the input. + We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + weight_ptr: Pointer to weight tensor. + loss_ptr: Pointer to tensor to store the loss. + z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. + loss_stride (int): The stride of the loss tensor. + token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0. + token_accuracy_stride (int): The stride of the token accuracy tensor. + n_cols (int): The number of columns in the input tensor. + n_rows (int): The total number of rows to process. + n_non_ignore (float): The number of non-ignored elements in the batch. 
+ sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch. + weight_sum (float): The sum of weight tensor. + ignore_index (int): The index to ignore in the target. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. + reduction (str): The string for the reduction to apply + softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). + RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1. + RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1. + BLOCK_SIZE (int): The block size for Triton operations. + HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes. + HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. + HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass. + """ + + # Grid-Stride Loop: each program processes multiple rows + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + start_row = pid + stride = num_progs + + for row_idx in range(start_row, n_rows, stride): + # https://github.com/triton-lang/triton/issues/1058 + # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64 + program_id = row_idx.to(tl.int64) + + # 1. Load Y_ptr first because if the target is ignore_index, we can return right away + Y_ptr_offset = program_id * Y_stride + y = tl.load(Y_ptr + Y_ptr_offset) + + # 2. locate the start index + X_ptr_offset = program_id * X_stride + + is_ignored = y == ignore_index + + if is_ignored: + # set all X_ptr as 0 + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(X_ptr + X_ptr_offset + X_offsets, 0.0, mask=X_offsets < n_cols) + # For ignored tokens, set token accuracy to 0 + if RETURN_TOKEN_ACCURACY: + token_accuracy_ptr_offset = program_id * token_accuracy_stride + tl.store(token_accuracy_ptr + token_accuracy_ptr_offset, 0.0) + if RETURN_PREDICTED_TOKENS: + predicted_tokens_ptr_offset = program_id * predicted_tokens_stride + tl.store(predicted_tokens_ptr + predicted_tokens_ptr_offset, -1) + else: + loss_ptr_offset = program_id * loss_stride + if RETURN_Z_LOSS: + z_loss_ptr_offset = program_id * loss_stride + if RETURN_TOKEN_ACCURACY: + token_accuracy_ptr_offset = program_id * token_accuracy_stride + if RETURN_PREDICTED_TOKENS: + predicted_tokens_ptr_offset = program_id * predicted_tokens_stride + + if HAS_WEIGHT: + weight_y = tl.load(weight_ptr + y).cast(tl.float32) + + # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) + # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 + + # 3. [Online softmax] first pass: find max + sum + m = float("-inf") # m is the max value. use the notation from the paper + d = 0.0 # d is the sum. 
use the notation from the paper + argmax_idx = 0 # Track the index of the maximum value for token accuracy / predicted tokens computation + ori_X_y = tl.load(X_ptr + X_ptr_offset + y).cast( + tl.float32 + ) # we need to store the original value of X_y for the loss calculation + if HAS_SOFTCAPPING: + ori_X_y = softcap * tanh(ori_X_y / softcap) + + # Label smoothing is a general case of normal cross entropy + # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310 + scaled_x_sum = 0.0 + eps = label_smoothing / n_cols + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_ptr_offset + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + X_block = softcap * tanh(X_block / softcap) + block_max = tl.max(X_block) + + # Track argmax for accuracy / predicted tokens computation + if RETURN_TOKEN_ACCURACY or RETURN_PREDICTED_TOKENS: + # Find the index of the maximum value in this block + is_max_mask = X_block == block_max + # Mask out invalid indices with a value larger than n_cols + masked_offsets = tl.where(is_max_mask, X_offsets, n_cols) + # Get the first (smallest) index where max occurs + current_block_argmax_idx = tl.min(masked_offsets) + + is_new_max = block_max > m + argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx) + + if label_smoothing > 0: + # scale X beforehand to avoid overflow + if HAS_WEIGHT: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0)) + else: + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) + m = m_new + + # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X))))) + # = log (e^(max(X)) * sum(e ^ (X_i - max(X)))) + # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d + lse = m + tl.log(d) + + # 4. 
[Online Softmax] Second pass: compute gradients + # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) + # dx_y = (softmax(x_y) - 1) / N + # dx_i = softmax(x_i) / N, i != y + # For label smoothing: + # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N + # With Z loss: + # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y + # dx_y = dx_i - (1 - label_smoothing) / N + # For 'sum' reduction, no normalization is applied: + # dx_y = softmax(x_y) - 1 + # dx_i = softmax(x_i), for i ≠ y + if HAS_GRADIENTS: + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_ptr_offset + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + intermediate = tanh(X_block / softcap) + X_block = softcap * intermediate + + if not HAS_WEIGHT: + # softmax(x_i) + X_block = tl.exp(X_block - m) / d + # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) + X_block += 2 * lse_square_scale * lse * X_block + # smoothing term + X_block += -eps + # special handle dx_y + X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) + # reduction scale + if reduction == "mean": + X_block = X_block / n_non_ignore + else: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + softmax_X = tl.exp(X_block - m) / d + # derivative of original_loss + dloss_ori = (1 - label_smoothing) * softmax_X + # specially handle dx_y + dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing)) + dloss_ori = dloss_ori * weight_y + # derivative of smooth_loss + dloss_smooth = eps * (-weight_block + softmax_X * weight_sum) + # derivative of z-loss + dz_loss = 2 * lse_square_scale * lse * softmax_X + # reduction scale + if reduction == "mean": + dloss_ori = dloss_ori / sum_non_ignore_weight + dloss_smooth = dloss_smooth / sum_non_ignore_weight + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. + dz_loss = dz_loss / n_non_ignore + # derivative of total_loss + X_block = dloss_ori + dloss_smooth + dz_loss + + # chain rule softcapping + # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) + if HAS_SOFTCAPPING: + X_block = X_block * (1 - intermediate * intermediate) + + tl.store(X_ptr + X_ptr_offset + X_offsets, X_block, mask=X_offsets < n_cols) + + # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in + # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34 + tl.debug_barrier() + + # 5. 
Calculate the loss
+
+            # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
+            #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
+            #      = X_y - m - log d = X_y - lse
+            # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
+            # So we can safely calculate log (softmax(X_y)) without overflow
+            loss = lse - ori_X_y
+            if HAS_WEIGHT:
+                loss = weight_y * loss
+
+            # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
+            # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
+            #          = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
+            # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
+            #          = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
+            # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
+            # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
+            # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
+            if label_smoothing > 0:
+                if HAS_WEIGHT:
+                    smooth_loss = scaled_x_sum + eps * lse * weight_sum
+                else:
+                    smooth_loss = scaled_x_sum + label_smoothing * lse
+                loss = loss * (1 - label_smoothing) + smooth_loss
+
+            # An auxiliary loss, z_loss
+            # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
+            z_loss = lse_square_scale * lse * lse
+            # Normalize the loss by the number of non-ignored elements if reduction is "mean"
+            if reduction == "mean":
+                if HAS_WEIGHT:
+                    loss = loss / sum_non_ignore_weight
+                else:
+                    loss = loss / n_non_ignore
+                # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                z_loss = z_loss / n_non_ignore
+            loss += z_loss
+
+            tl.store(loss_ptr + loss_ptr_offset, loss)
+            if RETURN_Z_LOSS:
+                tl.store(z_loss_ptr + z_loss_ptr_offset, z_loss)
+            if RETURN_TOKEN_ACCURACY:
+                # Store 1.0 if prediction is correct, 0.0 otherwise
+                is_correct = 1.0 if argmax_idx == y else 0.0
+                tl.store(token_accuracy_ptr + token_accuracy_ptr_offset, is_correct)
+            if RETURN_PREDICTED_TOKENS:
+                tl.store(predicted_tokens_ptr + predicted_tokens_ptr_offset, argmax_idx)
+
+
+def get_optimal_block_size(n_cols, has_gradients=True):
+    """
+    Calculate the optimal BLOCK_SIZE using compute_default_tiling_strategy.
+    """
+    # Cross entropy is more memory intensive than swiglu because it needs the softmax computation:
+    # the forward pass needs the online softmax calculation, and the backward pass needs more memory for intermediate variables.
+    # 12.0 and 8.0 are empirical values based on the Atlas 800I A2 UB (192KB).
+    multiplier = 12.0 if has_gradients else 8.0
+
+    # Call the calculation function.
+    # Treat the input as 1D (n_cols,), tiling only on dim 0.
+    tile_shapes = compute_default_tiling_strategy(
+        safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((n_cols,),), tiling_dims=(0,)
+    )
+
+    # Parse the result
+    if tile_shapes and len(tile_shapes) > 0:
+        block_size = tile_shapes[0][0]
+        return block_size
+    else:
+        return 2048
+
+
+def cross_entropy_forward(
+    _input,
+    target,
+    weight,
+    ignore_index,
+    lse_square_scale,
+    label_smoothing,
+    reduction,
+    softcap,
+    return_z_loss,
+    return_token_accuracy=False,
+    return_predicted_tokens=False,
+):
+    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. 
Got: {return_token_accuracy}" + ) + assert isinstance(return_predicted_tokens, bool), ( + f"return_predicted_tokens must be True or False. Got: {return_predicted_tokens}" + ) + + BT, V = _input.shape + n_rows = BT + + BLOCK_SIZE = get_optimal_block_size(V, has_gradients=_input.requires_grad) + + # unreduced loss + loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) + z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None + token_accuracy_1d = ( + torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None + ) + predicted_tokens_1d = ( + torch.full((n_rows,), -1, dtype=torch.int64, device=_input.device) if return_predicted_tokens else None + ) + + target_mask = target != ignore_index + n_non_ignore = target_mask.sum().item() + assert (target * target_mask).max() < _input.shape[-1], ( + f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}" + ) + assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0" + sum_non_ignore_weight = n_non_ignore + weight_sum = 0.0 + if weight is not None: + assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}" + assert torch.is_floating_point(weight), ( + f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}" + ) + sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item() + weight_sum = weight.sum().item() + # ensure weight is contiguous + if weight.stride(-1) != 1: + weight = weight.contiguous() + + # ensure _input and target are contiguous in the last dimension + if _input.stride(-1) != 1: + _input = _input.contiguous() + if target.stride(-1) != 1: + target = target.contiguous() + + # NPU-optimized grid configuration + num_cores = get_npu_core_count() + grid_size = min(num_cores, n_rows) + + # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory + liger_cross_entropy_kernel[(grid_size,)]( + X_ptr=_input, + X_stride=_input.stride(-2), + Y_ptr=target, + Y_stride=target.stride(-1), # always 1 + weight_ptr=weight, # dummy if None + loss_ptr=loss_1d, + z_loss_ptr=z_loss_1d, + loss_stride=loss_1d.stride(-1), # always 1 + token_accuracy_ptr=token_accuracy_1d, + token_accuracy_stride=token_accuracy_1d.stride(-1) + if return_token_accuracy + else 0, # always 1 if accuracy is enabled + predicted_tokens_ptr=predicted_tokens_1d, + predicted_tokens_stride=predicted_tokens_1d.stride(-1) + if return_predicted_tokens + else 0, # always 1 if predicted tokens is enabled + n_cols=V, + n_rows=n_rows, + n_non_ignore=n_non_ignore, + sum_non_ignore_weight=sum_non_ignore_weight, + ignore_index=ignore_index, + weight_sum=weight_sum, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, + RETURN_Z_LOSS=return_z_loss, + RETURN_TOKEN_ACCURACY=return_token_accuracy, + RETURN_PREDICTED_TOKENS=return_predicted_tokens, + BLOCK_SIZE=BLOCK_SIZE, + HAS_WEIGHT=True if weight is not None else False, + HAS_SOFTCAPPING=True if softcap is not None else False, + HAS_GRADIENTS=_input.requires_grad, + ) + + if reduction == "none": + loss = loss_1d + z_loss = z_loss_1d if return_z_loss else None + token_accuracy = token_accuracy_1d if return_token_accuracy else None + else: + loss = torch.sum(loss_1d) + z_loss = torch.sum(z_loss_1d) if return_z_loss else None + # For accuracy, we compute the mean across all non-ignored tokens + token_accuracy 
= torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
+
+    predicted_tokens = predicted_tokens_1d if return_predicted_tokens else None
+
+    return loss, z_loss, token_accuracy, predicted_tokens, _input
+
+
+def cross_entropy_backward(_input, grad_output):
+    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
+    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+        pass
+    # If reduction is 'none'
+    elif grad_output.ndim > 0:
+        _input = _input * grad_output.unsqueeze(dim=1)
+    # If reduction is ['mean', 'sum'], grad_output is just a scalar
+    # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
+    # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
+    else:
+        BT, V = _input.shape
+        n_rows = BT
+        BLOCK_SIZE = min(2048, triton.next_power_of_2(V))
+
+        element_mul_kernel[(n_rows,)](
+            _input,
+            _input.stride(-2),
+            grad_output,
+            V,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
+
+    return _input
+
+
+class LigerCrossEntropyFunction(torch.autograd.Function):
+    """
+    This class implements a custom autograd function for the Liger Cross Entropy loss.
+    It overrides the forward and backward methods of the torch.autograd.Function class.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input: torch.Tensor,
+        target: torch.Tensor,
+        weight: Optional[torch.FloatTensor],
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+        return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
+        return_predicted_tokens: bool = False,
+    ):
+        """
+        The forward pass of the Liger Cross Entropy loss.
+
+        Parameters:
+        ctx : The context object.
+        _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
+        target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+        weight (Tensor, optional): A manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype.
+        ignore_index (int): The index to ignore in the target.
+        lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
+        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+        reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
+        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy, predicted_tokens) instead of (loss, None, None, None). Default: `False`
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
+        return_predicted_tokens (bool): When `return_predicted_tokens` is `True`, returns per-token predicted class indices (argmax) without materializing logits. Default: `False`
+
+        Returns:
+        tuple: A tuple with the computed losses, accuracy, and predicted tokens: (loss, z_loss, token_accuracy, predicted_tokens). z_loss, token_accuracy, and predicted_tokens are None if not requested. 
+ """ + input_requires_grad = _input.requires_grad + + loss, z_loss, token_accuracy, predicted_tokens, _input = cross_entropy_forward( + _input, + target, + weight, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, + return_token_accuracy, + return_predicted_tokens, + ) + if input_requires_grad: + ctx.save_for_backward(_input.detach()) + ctx.return_z_loss = return_z_loss + ctx.return_token_accuracy = return_token_accuracy + ctx.return_predicted_tokens = return_predicted_tokens + + return loss, z_loss, token_accuracy, predicted_tokens + + @staticmethod + def backward(ctx, grad_output, grad_output2, grad_output3, grad_output4): + """ + The backward pass of the Liger Cross Entropy loss. + + Parameters: + ctx : The context object with saved tensors. + grad_output (tensor): The tensor containing the gradient of the loss with respect to the output. + grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging). + grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics). + grad_output4 (tensor): No use. Gradient for predicted_tokens (not used as predicted_tokens is only for metrics). + Returns: + tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. + """ + if ctx.return_z_loss: + del grad_output2 # z_loss is only for logging + if ctx.return_token_accuracy: + del grad_output3 # token_accuracy is only for metrics + if ctx.return_predicted_tokens: + del grad_output4 # predicted_tokens is only for metrics + + (_input,) = ctx.saved_tensors + _input = cross_entropy_backward(_input, grad_output) + return ( + _input, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) diff --git a/src/liger_kernel/ops/backends/_ascend/ops/dyt.py b/src/liger_kernel/ops/backends/_ascend/ops/dyt.py new file mode 100755 index 0000000000000000000000000000000000000000..cdcb1a327fd940ec3c3c734a60816cf130b9d856 --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/dyt.py @@ -0,0 +1,285 @@ +import torch +import triton +import triton.language as tl + +from triton.language.math import tanh + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count + +# ----------------------------------------------------------------------------- +# Forward Kernel +# ----------------------------------------------------------------------------- + + +@triton.jit +def _dyt_fwd_kernel( + X, + Y, + Alpha, + Gamma, + Beta, + HAVE_BETA: tl.constexpr, + M: tl.constexpr, + N: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + Forward kernel for DYT: y = tanh(α·x) · γ + β + + Grid: (num_col_blocks, num_row_programs) + Each program processes multiple rows using grid-stride loop + """ + pid_n = tl.program_id(0) + pid_m = tl.program_id(1) + num_row_programs = tl.num_programs(1) + + col_start = pid_n * BLOCK_N + col_offsets = col_start + tl.arange(0, BLOCK_N) + col_mask = col_offsets < N + + alpha = tl.load(Alpha).to(tl.float32) + gamma = tl.load(Gamma + col_offsets, mask=col_mask, other=0.0).to(tl.float32) + if HAVE_BETA: + beta = tl.load(Beta + col_offsets, mask=col_mask, other=0.0).to(tl.float32) + + # Grid-stride loop over rows + for row_idx in range(pid_m, M, num_row_programs): + row_offset = row_idx * N + + x = tl.load(X + row_offset + col_offsets, mask=col_mask, other=0.0).to(tl.float32) + + # Compute: y = 
tanh(α·x) · γ + β + tanh_x = tanh(alpha * x) + y = tanh_x * gamma + + if HAVE_BETA: + y += beta + + tl.store(Y + row_offset + col_offsets, y, mask=col_mask) + + +# ----------------------------------------------------------------------------- +# Backward Kernel +# ----------------------------------------------------------------------------- + + +@triton.jit +def _dyt_bwd_kernel( + DY, + DX, + DA, + DG, + DB, + X, + Alpha, + Gamma, + HAVE_BETA: tl.constexpr, + M: tl.constexpr, + N: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + Backward kernel for DYT + + Grid: (num_col_blocks, num_row_programs) + Each program processes multiple rows using grid-stride loop + Accumulates gradients in local buffers, then stores to global memory + """ + pid_n = tl.program_id(0) + pid_m = tl.program_id(1) + num_row_programs = tl.num_programs(1) + + col_start = pid_n * BLOCK_N + col_offsets = col_start + tl.arange(0, BLOCK_N) + col_mask = col_offsets < N + + alpha = tl.load(Alpha).to(tl.float32) + gamma = tl.load(Gamma + col_offsets, mask=col_mask, other=0.0).to(tl.float32) + + da_vec = tl.zeros((BLOCK_N,), dtype=tl.float32) + dg_acc = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAVE_BETA: + db_acc = tl.zeros((BLOCK_N,), dtype=tl.float32) + + # Grid-stride loop over rows + for row_idx in range(pid_m, M, num_row_programs): + row_offset = row_idx * N + + x = tl.load(X + row_offset + col_offsets, mask=col_mask, other=0.0).to(tl.float32) + dy = tl.load(DY + row_offset + col_offsets, mask=col_mask, other=0.0).to(tl.float32) + + tanh_x = tanh(alpha * x) + + if HAVE_BETA: + db_acc += dy + + dg_acc += dy * tanh_x + + # Compute intermediate: tmp = (1 - tanh²) · dy · γ + tmp = (1.0 - tanh_x * tanh_x) * dy * gamma + + # Accumulate dα = Σ(x · tmp) + da_vec += x * tmp + + # Compute dx = α · tmp + dx = alpha * tmp + tl.store(DX + row_offset + col_offsets, dx, mask=col_mask) + + da_acc = tl.sum(da_vec, 0) + da_offset = pid_m * triton.cdiv(N, BLOCK_N) + pid_n + tl.store(DA + da_offset, da_acc) + + dg_offset = pid_m * N + col_offsets + tl.store(DG + dg_offset, dg_acc, mask=col_mask) + + if HAVE_BETA: + db_offset = pid_m * N + col_offsets + tl.store(DB + db_offset, db_acc, mask=col_mask) + + +def get_optimal_block_size(total_elements, is_backward=False): + """ + Calculate optimal Block Size using compute_default_tiling_strategy + """ + multiplier = 8.0 if is_backward else 4.0 + + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((total_elements,),), tiling_dims=(0,) + ) + + if tile_shapes and len(tile_shapes) > 0: + block_size = tile_shapes[0][0] + return block_size + else: + return 2048 + + +def _compute_grid_size(n_cols, n_rows, block_n): + """ + Compute grid size to avoid launching idle programs + + Args: + n_cols: Number of columns + n_rows: Number of rows + block_n: Block size for column dimension + + Returns: + (num_col_blocks, num_row_programs) + """ + num_cores = get_npu_core_count() + num_col_blocks = triton.cdiv(n_cols, block_n) + num_row_blocks = n_rows + + num_row_programs = min(max(1, (num_cores // num_col_blocks)), num_row_blocks) + + return num_col_blocks, num_row_programs + + +# ----------------------------------------------------------------------------- +# Python Wrapper Functions +# ----------------------------------------------------------------------------- + + +def liger_dyt_fwd(x, alpha, gamma, beta): + """ + Forward pass of DYT: y = tanh(α·x) · γ + β + + Args: + x: Input tensor of shape [..., N] + alpha: Scalar parameter + gamma: 
Vector parameter of shape [N] + beta: Vector parameter of shape [N] (optional) + + Returns: + y: Output tensor of same shape as x + """ + assert x.is_contiguous() + HAVE_BETA = beta is not None + + # Flatten to 2D + input_shape = x.shape + x = x.view(-1, input_shape[-1]) + M, N = x.shape + + # Allocate output + y = torch.empty_like(x) + + block_n = get_optimal_block_size(N, is_backward=False) + + # Compute grid size + num_col_blocks, num_row_programs = _compute_grid_size(N, M, block_n) + grid = (num_col_blocks, num_row_programs) + + # Launch kernel + _dyt_fwd_kernel[grid](x, y, alpha, gamma, beta, HAVE_BETA, M, N, BLOCK_N=block_n) + + return y.view(input_shape) + + +def liger_dyt_bwd(dy, x, alpha, gamma, beta): + """ + Backward pass of DYT + + Args: + dy: Upstream gradient of shape [..., N] + x: Input tensor of shape [..., N] + alpha: Scalar parameter + gamma: Vector parameter of shape [N] + beta: Vector parameter of shape [N] (optional) + + Returns: + dx: Gradient w.r.t. x + dalpha: Gradient w.r.t. alpha + dgamma: Gradient w.r.t. gamma + dbeta: Gradient w.r.t. beta (or None) + """ + assert dy.is_contiguous() + HAVE_BETA = beta is not None + + # Flatten to 2D + input_shape = x.shape + x = x.view(-1, input_shape[-1]) + dy = dy.view(-1, input_shape[-1]) + M, N = x.shape + + block_n = get_optimal_block_size(N, is_backward=True) + + # Compute grid size + num_col_blocks, num_row_programs = _compute_grid_size(N, M, block_n) + grid = (num_col_blocks, num_row_programs) + + da = torch.zeros(num_row_programs, triton.cdiv(N, block_n), dtype=torch.float32, device=x.device) + dg = torch.empty(num_row_programs, N, dtype=torch.float32, device=x.device) + db = torch.empty(num_row_programs, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None + dx = torch.empty_like(dy) + + _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, BLOCK_N=block_n) + + da = da.sum().to(x.dtype).unsqueeze(0) + dg = dg.sum(0).to(gamma.dtype) + db = db.sum(0).to(x.dtype) if HAVE_BETA else None + + return dx.view(input_shape), da, dg, db + + +# ----------------------------------------------------------------------------- +# Autograd Function +# ----------------------------------------------------------------------------- + + +class LigerDyTFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, x, alpha, gamma, beta): + y = liger_dyt_fwd(x, alpha, gamma, beta) + ctx.save_for_backward(x, alpha, gamma, beta) + return y + + @staticmethod + @ensure_contiguous + def backward(ctx, dy): + x, alpha, gamma, beta = ctx.saved_tensors + dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta) + return dx, dalpha, dgamma, dbeta diff --git a/src/liger_kernel/ops/backends/_ascend/ops/embedding.py b/src/liger_kernel/ops/backends/_ascend/ops/embedding.py new file mode 100755 index 0000000000000000000000000000000000000000..dc6016d41eb1b45f5aa0dfb3b5098269e56713ee --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/embedding.py @@ -0,0 +1,210 @@ +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count + + +@triton.jit +def embedding_forward_kernel( + embeddings_ptr, + indices_ptr, + output_ptr, + n_elements, + embedding_dim: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + 
grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M) + grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N) + total_2d_blocks = grid_m * grid_n + + for block_idx in tl.range(pid, total_2d_blocks, num_progs): + block_m = block_idx // grid_n + block_n = block_idx % grid_n + + start_m = block_m * BLOCK_SIZE_M + start_n = block_n * BLOCK_SIZE_N + + offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M) + mask_m = offsets_m < n_elements + + indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0) + + offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N) + mask_n = offsets_n < embedding_dim + + block_mask = mask_m[:, None] & mask_n[None, :] + + embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :] + embeddings = tl.load( + embeddings_ptr + embedding_offsets, + mask=block_mask, + other=0.0, + ) + + output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :] + tl.store( + output_ptr + output_offsets, + embeddings, + mask=block_mask, + ) + + +@triton.jit +def embedding_backward_kernel( + grad_output_ptr, + grad_weight_ptr, + indices_ptr, + n_elements, + embedding_dim: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M) + grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N) + total_2d_blocks = grid_m * grid_n + + for block_idx in tl.range(pid, total_2d_blocks, num_progs): + block_m = block_idx // grid_n + block_n = block_idx % grid_n + + start_m = block_m * BLOCK_SIZE_M + start_n = block_n * BLOCK_SIZE_N + + offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M) + mask_m = offsets_m < n_elements + + indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0) + + offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N) + mask_n = offsets_n < embedding_dim + + block_mask = mask_m[:, None] & mask_n[None, :] + + grad_output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :] + grad_output = tl.load( + grad_output_ptr + grad_output_offsets, + mask=block_mask, + other=0.0, + ) + + grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :] + tl.atomic_add( + grad_weight_ptr + grad_weight_offsets, + grad_output, + mask=block_mask, + ) + + +def get_optimal_block_size(total_elements, dtype_size, BLOCK_SIZE_N: tl.constexpr): + # 1. Set Memory Multiplier + # 3.0 are empirical values based on Atlas 800I A2 UB (192KB) + # embedding_offsets, embedding_offsets : BLOCK_SIZE_N * BLOCK_SIZE_M (total 2 * BLOCK_SIZE_N * BLOCK_SIZE_M) + # Reserve a unit of space for the remaining one-dimensional ub to occupy. + # A conservative estimate of the total space occupation is 3 * BLOCK_SIZE_N * BLOCK_SIZE_M + multiplier = 3.0 + + # 2. Call calculation function + # Treat input as 1D (total_elements,), only tiling on dim 0 + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.9, + dtype_size=dtype_size, + memory_multiplier=multiplier, + shapes=((total_elements, BLOCK_SIZE_N),), + tiling_dims=(0,), + ) + + # 3. 
Parse result + if tile_shapes and len(tile_shapes) > 0: + block_size = tile_shapes[0][0] + return block_size + else: + return triton.next_power_of_2(min(128, total_elements)) + + +def embedding_forward(embeddings, indices): + ori_shape = indices.shape + indices = indices.view(-1) + + n_elements = indices.numel() + embedding_dim = embeddings.shape[1] + output = torch.empty( + indices.shape[0], + embeddings.shape[1], + device=indices.device, + dtype=embeddings.dtype, + ) + + # Due to the involvement of two-dimensional partitioning, + # the sizes of block_m and block_n in the ub space will influence each other. + # Considering that embedding_dim is usually relatively smaller in most cases, + # a value is first assigned to block_n, and then the largest possible block_m is used. + BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim)) + BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N) + num_cores = get_npu_core_count() + total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N) + grid = min(num_cores, total_blocks) + + embedding_forward_kernel[(grid,)]( + embeddings, + indices, + output, + n_elements, + embedding_dim=embedding_dim, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + + return output.view(*ori_shape, -1) + + +def embedding_backward(embeddings, indices, grad_output): + grad_output = grad_output.contiguous().view(-1, embeddings.shape[1]) + + grad_weight = torch.zeros_like(embeddings) + + n_elements = indices.numel() + embedding_dim = embeddings.shape[1] + BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim)) + BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N) + num_cores = get_npu_core_count() + total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N) + grid = min(num_cores, total_blocks) + + embedding_backward_kernel[(grid,)]( + grad_output, + grad_weight, + indices, + n_elements, + embedding_dim=embedding_dim, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + + return grad_weight + + +class LigerEmbeddingFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor): + output = embedding_forward(embeddings, indices) + ctx.save_for_backward(indices, embeddings) + return output + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor): + indices, embeddings = ctx.saved_tensors + grad_weight = embedding_backward(embeddings, indices, grad_output) + + return grad_weight, None diff --git a/src/liger_kernel/ops/backends/_ascend/ops/fused_add_rms_norm.py b/src/liger_kernel/ops/backends/_ascend/ops/fused_add_rms_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..07e4c0db0d5c3e7ba8319a30a727edb604698cb3 --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/fused_add_rms_norm.py @@ -0,0 +1,781 @@ +import torch +import triton +import triton.language as tl + +from triton.language.math import rsqrt + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count +from liger_kernel.ops.utils import torch_to_triton_dtype + +_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1) +_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0) +_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1) + + +def torch_dtype_to_triton(dtype): + mapping 
= { + torch.float32: tl.float32, + torch.bfloat16: tl.bfloat16, + } + return mapping.get(dtype, tl.float32) + + +# ----------------------------------------------------------------------------- +# Forward Kernel - No Tiling (for n_cols <= 2048) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _fused_add_rms_norm_forward_kernel_no_tiling( + Y_ptr, + Y_row_stride, + S_ptr, # output residual + S_row_stride, + X_ptr, + X_row_stride, + R_ptr, # input residual + R_row_stride, + W_ptr, + RSTD_ptr, + RSTD_row_stride, + n_rows, + n_cols, + eps, + offset, + casting_mode: tl.constexpr, + X_DTYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + """ + NPU-optimized fused_add_rms_norm forward kernel for small n_cols (< 2048). + + Performance optimizations: + 1. Keep S_row in registers, avoid reload from memory + 2. Minimize type conversions by careful ordering + 3. Use optimal cache policies + 4. Preload W_row while computing rstd (instruction-level parallelism) + 5. Use 2D vector loading to maximize UB utilization (e.g., (1,2048), (2,1024), (4,512)) + + Used when n_cols < 2048 to avoid the overhead of column blocking. + """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + if casting_mode == _CASTING_MODE_NONE: + eps = eps.to(X_DTYPE) + offset = offset.to(X_DTYPE) + + # Grid-stride loop setup for 2D blocks + grid_stride = num_progs * BLOCK_SIZE_M + num_iterations = tl.cdiv(n_rows, grid_stride) + + col_offsets = tl.arange(0, BLOCK_SIZE_N) + col_mask = col_offsets < n_cols + row_offsets = tl.arange(0, BLOCK_SIZE_M) + + W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0) + + # Grid-stride loop over row blocks + for i in tl.range(num_iterations): + row_idx = i * grid_stride + pid * BLOCK_SIZE_M + row_offsets + row_mask = row_idx < n_rows + block_mask = row_mask[:, None] & col_mask[None, :] + + # Load multiple rows at once using 2D indexing + X_rows = tl.load( + X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :], + mask=block_mask, + other=0.0, + eviction_policy="evict_first", + ) + R_rows = tl.load( + R_ptr + row_idx[:, None] * R_row_stride + col_offsets[None, :], + mask=block_mask, + other=0.0, + eviction_policy="evict_first", + ) + S_rows = X_rows + R_rows + + # Compute sum_square for all rows + if casting_mode == _CASTING_MODE_LLAMA or casting_mode == _CASTING_MODE_GEMMA: + S_rows = S_rows.to(tl.float32) + + sum_squares = tl.sum(tl.where(block_mask, S_rows * S_rows, 0.0), axis=1) + + # Compute rstd for all rows + mean_squares = sum_squares / n_cols + rstd_rows = rsqrt(mean_squares + eps) + + # Store S_rows and rstd_rows + tl.store( + S_ptr + row_idx[:, None] * S_row_stride + col_offsets[None, :], + S_rows, + mask=block_mask, + cache_modifier=".cg", + ) + tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd_rows, mask=row_mask) + + # Normalize and apply weight - optimized for each casting mode + if casting_mode == _CASTING_MODE_GEMMA: + Y_rows = ((S_rows * rstd_rows[:, None]) * (offset + W_row[None, :])).to(X_DTYPE) + elif casting_mode == _CASTING_MODE_LLAMA: + S_normalized = (S_rows * rstd_rows[:, None]).to(X_DTYPE) + Y_rows = S_normalized * (offset + W_row[None, :]) + else: + Y_rows = (S_rows * rstd_rows[:, None]) * (offset + W_row[None, :]) + + # Store results + tl.store(Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :], Y_rows, mask=block_mask) + + +# ----------------------------------------------------------------------------- +# Forward Kernel - With Tiling (for n_cols 
> 2048) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _fused_add_rms_norm_forward_kernel_npu( + Y_ptr, + Y_row_stride, + S_ptr, # output residual + S_row_stride, + X_ptr, + X_row_stride, + R_ptr, # input residual + R_row_stride, + W_ptr, + RSTD_ptr, + RSTD_row_stride, + n_rows, + n_cols, + eps, + offset, + casting_mode: tl.constexpr, + X_DTYPE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + NPU-optimized fused_add_rms_norm forward kernel. + + This kernel processes rows using a grid-stride loop pattern: + 1. Each program handles multiple rows + 2. For each row, we process it in column chunks of BLOCK_SIZE_N + 3. Grid size is limited to NPU core count to avoid resource overflow + + This solves two problems: + 1. UB overflow when n_cols is too large (original kernel used n_cols as BLOCK_SIZE_N) + 2. Efficient multi-row processing within a single kernel launch + """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE) + + if casting_mode == _CASTING_MODE_NONE: + eps = eps.to(X_DTYPE) + offset = offset.to(X_DTYPE) + + offsets = tl.arange(0, BLOCK_SIZE) + # Grid-stride loop over rows + for row_idx in tl.range(pid, n_rows, num_progs): + Y_row_ptr = Y_ptr + row_idx * Y_row_stride + S_row_ptr = S_ptr + row_idx * S_row_stride + X_row_ptr = X_ptr + row_idx * X_row_stride + R_row_ptr = R_ptr + row_idx * R_row_stride + RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride + + # Accumulator for mean_square computation across all column blocks + sum_square = 0.0 + + # First pass: compute S_row = X_row + R_row and accumulate sum of squares + for col_block_idx in range(num_col_blocks): + col_start = col_block_idx * BLOCK_SIZE + col_offsets = col_start + offsets + mask = col_offsets < n_cols + + X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first") + R_block = tl.load(R_row_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first") + S_block = X_block + R_block + + # Store S_row + tl.store(S_row_ptr + col_offsets, S_block, mask=mask, cache_modifier=".cg") + + if casting_mode == _CASTING_MODE_LLAMA or casting_mode == _CASTING_MODE_GEMMA: + S_block = S_block.to(tl.float32) + + # Accumulate sum of squares (only for valid elements) + sum_square += tl.sum(tl.where(mask, S_block * S_block, 0.0)) + + # Compute rstd for this row + mean_square = sum_square / n_cols + + rstd = rsqrt(mean_square + eps) + + # Store rstd + tl.store(RSTD_row_ptr, rstd) + + # Second pass: normalize and multiply by weight + for col_block_idx in range(num_col_blocks): + col_start = col_block_idx * BLOCK_SIZE + col_offsets = col_start + offsets + mask = col_offsets < n_cols + + # Load S_block (already computed in first pass) + S_block = tl.load(S_row_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".ca") + W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0) + + # Apply casting based on mode + if casting_mode == _CASTING_MODE_GEMMA: + S_block = S_block.to(tl.float32) + W_block = W_block.to(tl.float32) + elif casting_mode == _CASTING_MODE_LLAMA: + S_block = S_block.to(tl.float32) + + # Normalize + S_block = S_block * rstd + + # Cast back for Llama mode before weight multiplication + if casting_mode == _CASTING_MODE_LLAMA: + S_block = S_block.to(X_DTYPE) + # Apply weight + Y_block = S_block * (offset + W_block) + + # Cast back for Gemma mode + if casting_mode == _CASTING_MODE_GEMMA: + Y_block = Y_block.to(X_DTYPE) + + # Store result + tl.store(Y_row_ptr + 
col_offsets, Y_block, mask=mask) + + +# ----------------------------------------------------------------------------- +# Backward Kernel - No Tiling (for n_cols < 2048) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _fused_add_rms_norm_backward_kernel_no_tiling( + dY_ptr, + dY_row_stride, + dS_out_ptr, + dS_out_row_stride, + dX_ptr, + dX_row_stride, + X_ptr, + X_row_stride, + X_dtype: tl.constexpr, + W_ptr, + RSTD_ptr, + RSTD_row_stride, + dW_ptr, + dW_row_stride, + n_rows, + n_cols, + offset, + casting_mode: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + has_dS_out: tl.constexpr, +): + """ + NPU-optimized fused_add_rms_norm backward kernel for small n_cols (< 2048). + + Performance optimizations: + 1. Keep all data in registers, minimize conversions + 2. Reuse X_normalized (X * rstd) for both dX and dW + 3. Optimize computation order to reduce register pressure + 4. Combine operations where possible + 5. Use 2D vector loading to maximize UB utilization (e.g., (1,2048), (2,1024), (4,512)) + """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + # Grid-stride loop setup for 2D blocks + grid_stride = num_progs * BLOCK_SIZE_M + num_iterations = tl.cdiv(n_rows, grid_stride) + + col_offsets = tl.arange(0, BLOCK_SIZE_N) + col_mask = col_offsets < n_cols + row_offsets = tl.arange(0, BLOCK_SIZE_M) + + # Load W once for all iterations + W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0) + W_offset = W_row + offset + + # Grid-stride loop over row blocks + for i in tl.range(num_iterations): + row_idx = i * grid_stride + pid * BLOCK_SIZE_M + row_offsets + row_mask = row_idx < n_rows + block_mask = row_mask[:, None] & col_mask[None, :] + + dY_rows = tl.load( + dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :], + mask=block_mask, + other=0.0, + eviction_policy="evict_first", + ) + X_rows = tl.load( + X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :], + mask=block_mask, + other=0.0, + eviction_policy="evict_first", + ) + + # Load rstd for all rows in the block + rstd_rows = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, mask=row_mask, other=0.0) + + # Convert X to fp32 once + X_rows = X_rows.to(tl.float32) + + # Compute X_normalized (reused in dX and dW) + X_normalized = X_rows * rstd_rows[:, None] + + # Compute m based on casting mode (optimized for each mode) + if casting_mode == _CASTING_MODE_LLAMA: + m_rows = (dY_rows * W_offset[None, :]).to(tl.float32) + # For dW in Llama mode, we need X_normalized in original dtype + X_normalized_for_dW = X_normalized.to(X_dtype) + elif casting_mode == _CASTING_MODE_GEMMA: + m_rows = dY_rows.to(tl.float32) * W_offset[None, :] + X_normalized_for_dW = X_normalized + else: + m_rows = dY_rows * W_offset[None, :] + X_normalized_for_dW = X_normalized + + # Compute sum(m * X) for correction factor + sum_m_X = tl.sum(tl.where(block_mask, m_rows * X_rows, 0.0), axis=1) + + # Compute correction factor + correction_factors = -(1.0 / n_cols) * rstd_rows * rstd_rows * sum_m_X + + # Compute dX = rstd * m + rstd * correction_factor * X + dX_rows = rstd_rows[:, None] * m_rows + rstd_rows[:, None] * correction_factors[:, None] * X_rows + + # Add dS_out gradient if present + if has_dS_out: + dS_out_rows = tl.load( + dS_out_ptr + row_idx[:, None] * dS_out_row_stride + col_offsets[None, :], mask=block_mask, other=0.0 + ) + dX_rows += dS_out_rows + + # Store dX + tl.store(dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :], 
dX_rows.to(X_dtype), mask=block_mask) + + # Compute dW contribution: dY * X_normalized + dW_rows = (dY_rows * X_normalized_for_dW).to(tl.float32) + + # Accumulate to per-program dW buffer + dW_row_ptr = dW_ptr + pid * dW_row_stride + existing_dW = tl.load(dW_row_ptr + col_offsets, mask=col_mask, other=0.0) + new_dW = existing_dW + tl.sum(tl.where(block_mask, dW_rows, 0.0), axis=0) + tl.store(dW_row_ptr + col_offsets, new_dW, mask=col_mask) + + +# ----------------------------------------------------------------------------- +# Backward Kernel - With Tiling (for n_cols > 2048) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _fused_add_rms_norm_backward_kernel_npu( + dY_ptr, + dY_row_stride, + dS_out_ptr, + dS_out_row_stride, + dX_ptr, + dX_row_stride, + X_ptr, + X_row_stride, + X_dtype: tl.constexpr, + W_ptr, + RSTD_ptr, + RSTD_row_stride, + dW_ptr, + dW_row_stride, + n_rows, + n_cols, + offset, + casting_mode: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + has_dS_out: tl.constexpr, +): + """ + NPU-optimized fused_add_rms_norm backward kernel. + + Each program processes multiple rows using grid-stride loop. + For each row, we process columns in blocks to avoid UB overflow. + """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + # Initialize dW accumulator (per-program, will be reduced later) + num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE) + offsets = tl.arange(0, BLOCK_SIZE) + + # Grid-stride loop over rows + for row_idx in tl.range(pid, n_rows, num_progs): + # Base pointers for this row + dY_row_ptr = dY_ptr + row_idx * dY_row_stride + dX_row_ptr = dX_ptr + row_idx * dX_row_stride + X_row_ptr = X_ptr + row_idx * X_row_stride + RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride + + # Load rstd for this row + rstd = tl.load(RSTD_row_ptr) + + # First pass: compute sum(m * X) for the correction term + sum_m_X = 0.0 + + for col_block_idx in range(num_col_blocks): + col_start = col_block_idx * BLOCK_SIZE + col_offsets = col_start + offsets + mask = col_offsets < n_cols + + dY_block = tl.load(dY_row_ptr + col_offsets, mask=mask, other=0.0) + X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0) + W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0) + + # Convert to fp32 for computation + X_block = X_block.to(tl.float32) + + # Compute m based on casting mode + W_offset = W_block + offset + + if casting_mode == _CASTING_MODE_LLAMA: + m = (dY_block * W_offset).to(tl.float32) + elif casting_mode == _CASTING_MODE_GEMMA: + dY_block = dY_block.to(tl.float32) + m = dY_block * W_offset + else: + m = dY_block * W_offset + + # Accumulate sum(m * X) + sum_m_X += tl.sum(tl.where(mask, m * X_block, 0.0)) + + # Compute the correction factor + correction_factor = -(1.0 / n_cols) * rstd * rstd * sum_m_X + + # Second pass: compute gradients + for col_block_idx in range(num_col_blocks): + col_start = col_block_idx * BLOCK_SIZE + col_offsets = col_start + offsets + mask = col_offsets < n_cols + + dY_block = tl.load(dY_row_ptr + col_offsets, mask=mask, other=0.0) + X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0) + W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0) + + X_block = X_block.to(tl.float32) + + # Compute m based on casting mode + W_offset = W_block + offset + + if casting_mode == _CASTING_MODE_LLAMA: + m = (dY_block * W_offset).to(tl.float32) + elif casting_mode == _CASTING_MODE_GEMMA: + dY_block = dY_block.to(tl.float32) + m = dY_block * W_offset + else: + m = dY_block * W_offset + + # Compute dX + 
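# (per element: dX = rstd * m - (rstd^3 / n_cols) * sum(m * X) * X,
+            # which is rstd * m + rstd * correction_factor * X, since
+            # correction_factor = -(1 / n_cols) * rstd^2 * sum(m * X))
+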
dX_block = rstd * m + rstd * correction_factor * X_block + + # Add dS_out gradient if present + if has_dS_out: + dS_out_row_ptr = dS_out_ptr + row_idx * dS_out_row_stride + dS_out_block = tl.load(dS_out_row_ptr + col_offsets, mask=mask, other=0.0) + dX_block += dS_out_block + + # Store dX + tl.store(dX_row_ptr + col_offsets, dX_block.to(X_dtype), mask=mask) + + # Compute dW contribution (accumulate per program) + if casting_mode == _CASTING_MODE_LLAMA: + dW_block = dY_block * (X_block * rstd).to(X_dtype) + else: + dW_block = dY_block * (X_block * rstd) + + # Atomic add to dW_ptr (each program writes to its own row) + dW_row_ptr = dW_ptr + pid * dW_row_stride + + # Load existing dW, add contribution, store back + existing_dW = tl.load(dW_row_ptr + col_offsets, mask=mask, other=0.0) + new_dW = existing_dW + dW_block.to(tl.float32) + tl.store(dW_row_ptr + col_offsets, new_dW, mask=mask) + + +# ----------------------------------------------------------------------------- +# Helper Functions +# ----------------------------------------------------------------------------- + + +def get_optimal_block_size(n_cols, is_forward: bool): + """ + Calculate optimal block size for forward pass using compute_default_tiling_strategy. + + Memory analysis for forward pass (per row): + - Load: X_block, R_block, W_block (3 blocks) + - Store: S_block, Y_block (2 blocks) + - Compute: S_block, Y_block intermediate (2 blocks) + - Total: conservative estimate 8 blocks of memory + + Memory analysis for backward pass (per row): + - Load: dY_block, X_block, W_block, existing_dW (4 blocks) + - Store: dX_block, new_dW (2 blocks) + - Compute: m, dX_block intermediate, dW_block intermediate (3 blocks) + - Additional: dS_out_block if present (1 block) + - Total: conservative estimate 12 blocks of memory + + Args: + n_cols: Number of columns in the tensor + + Returns: + Optimal block size + """ + if n_cols <= 2048: + return triton.next_power_of_2(n_cols) + + memory_multiplier = 8.0 if is_forward else 12.0 + + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.9, + dtype_size=4, + memory_multiplier=memory_multiplier, + shapes=((n_cols,),), + tiling_dims=(0,), + ) + + if tile_shapes and len(tile_shapes) > 0: + block_size = tile_shapes[0][0] + return max(2048, block_size) + else: + return 2048 + + +# ----------------------------------------------------------------------------- +# Forward and Backward Functions +# ----------------------------------------------------------------------------- + + +_str_to_casting_mode = { + "llama": _CASTING_MODE_LLAMA.value, + "gemma": _CASTING_MODE_GEMMA.value, + "none": _CASTING_MODE_NONE.value, +} + + +def fused_add_rms_norm_forward(X, R, W, eps, offset, casting_mode): + if not isinstance(casting_mode, int): + assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}" + casting_mode = _str_to_casting_mode[casting_mode] + else: + assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}" + shape = X.shape + dim = shape[-1] + X = X.view(-1, dim) + R = R.view(-1, dim) + n_rows, n_cols = X.shape + X_DTYPE = torch_dtype_to_triton(X.dtype) + + # Get optimal block size for column processing + BLOCK_SIZE = get_optimal_block_size(n_cols, True) + BLOCK_SIZE_M = 2048 // BLOCK_SIZE + + Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + S = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + + # RSTD is always fp32 for Llama/Gemma modes + rstd_dtype = torch.float32 if casting_mode in 
(_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype + RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device) + + # Check constraints + assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension" + + # Grid size limited to NPU core count + num_cores = get_npu_core_count() + grid_size = min(num_cores * 2, n_rows) + + # Choose kernel based on n_cols + if n_cols <= 2048: + # Use no-tiling kernel for small n_cols + _fused_add_rms_norm_forward_kernel_no_tiling[(grid_size,)]( + Y, + Y.stride(0), + S, + S.stride(0), + X, + X.stride(0), + R, + R.stride(0), + W, + RSTD, + RSTD.stride(0), + n_rows, + n_cols, + eps, + offset, + casting_mode, + X_DTYPE, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE, + ) + else: + # Use tiled kernel for large n_cols + _fused_add_rms_norm_forward_kernel_npu[(grid_size,)]( + Y, + Y.stride(0), + S, + S.stride(0), + X, + X.stride(0), + R, + R.stride(0), + W, + RSTD, + RSTD.stride(0), + n_rows, + n_cols, + eps, + offset, + casting_mode, + X_DTYPE, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return Y.view(*shape), S.view(*shape), RSTD, casting_mode + + +def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, in_place): + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + if dS_out is not None: + dS_out = dS_out.view(-1, dim) + S = S.view(-1, dim) + n_rows, n_cols = dY.shape + + # Get NPU core count for grid size + num_cores = get_npu_core_count() + grid_size = min(num_cores * 2, n_rows) + + # Get optimal block size for backward pass + BLOCK_SIZE = get_optimal_block_size(n_cols, False) + BLOCK_SIZE_M = 2048 // BLOCK_SIZE + + # fp32 for numerical stability + _dW = torch.zeros((grid_size, n_cols), dtype=torch.float32, device=W.device) + + if in_place: + dX = dY + else: + dX = torch.empty_like(dY) + + # Choose kernel based on n_cols + if n_cols <= 2048: + # Use no-tiling kernel for small n_cols + _fused_add_rms_norm_backward_kernel_no_tiling[(grid_size,)]( + dY, + dY.stride(0), + dS_out if dS_out is not None else dY, # Dummy pointer if dS_out is None + dS_out.stride(0) if dS_out is not None else 0, + dX, + dX.stride(0), + S, + S.stride(0), + torch_to_triton_dtype[S.dtype], + W, + RSTD, + RSTD.stride(0), + _dW, + _dW.stride(0), + n_rows, + n_cols, + offset, + casting_mode, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE, + has_dS_out=dS_out is not None, + ) + else: + # Use tiled kernel for large n_cols + _fused_add_rms_norm_backward_kernel_npu[(grid_size,)]( + dY, + dY.stride(0), + dS_out if dS_out is not None else dY, # Dummy pointer if dS_out is None + dS_out.stride(0) if dS_out is not None else 0, + dX, + dX.stride(0), + S, + S.stride(0), + torch_to_triton_dtype[S.dtype], + W, + RSTD, + RSTD.stride(0), + _dW, + _dW.stride(0), + n_rows, + n_cols, + offset, + casting_mode, + BLOCK_SIZE=BLOCK_SIZE, + has_dS_out=dS_out is not None, + ) + + dX = dX.view(*shape) + dW = _dW.sum(dim=0).to(W.dtype) + + return dX, dX, dW # dR is equal to dX + + +# ----------------------------------------------------------------------------- +# Autograd Function +# ----------------------------------------------------------------------------- + + +class LigerFusedAddRMSNormFunction(torch.autograd.Function): + """ + NPU-optimized fused operation for residual addition and RMSNorm. + + This implementation solves two key issues: + 1. UB overflow when n_cols is too large (by using column-wise blocking) + 2. 
Efficient multi-row processing (by using grid-stride loop with core count limit) + """ + + @staticmethod + @ensure_contiguous + def forward(ctx, X, R, W, eps, offset=0.0, casting_mode="llama", in_place=False): + """ + X: (B, T, H) or (BxT, H) + R: (B, T, H) or (BxT, H) + W: (H,) + """ + Y, S, RSTD, casting_mode = fused_add_rms_norm_forward(X, R, W, eps, offset, casting_mode) + ctx.offset = offset + ctx.casting_mode = casting_mode + ctx.in_place = in_place + ctx.save_for_backward(S, W, RSTD) + return Y, S + + @staticmethod + @ensure_contiguous + def backward(ctx, dY, dS_out): + """ + dY: (B, T, H) or (BxT, H) + dS_out: (B, T, H) or (BxT, H) + """ + S, W, RSTD = ctx.saved_tensors + dX, dR, dW = fused_add_rms_norm_backward( + dY, + dS_out, + S, + W, + RSTD, + ctx.offset, + ctx.casting_mode, + ctx.in_place, + ) + + return dX, dR, dW, None, None, None, None diff --git a/src/liger_kernel/ops/backends/_ascend/ops/fused_linear_jsd.py b/src/liger_kernel/ops/backends/_ascend/ops/fused_linear_jsd.py new file mode 100755 index 0000000000000000000000000000000000000000..10c8fc354217d6c6d511358346a9466670a57fa2 --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/fused_linear_jsd.py @@ -0,0 +1,227 @@ +from typing import Optional + +import torch +import triton + +from liger_kernel.ops.backends._ascend.ops.jsd import _jsd_kernel +from liger_kernel.ops.utils import amp_custom_bwd +from liger_kernel.ops.utils import amp_custom_fwd +from liger_kernel.ops.utils import element_mul_kernel +from liger_kernel.ops.utils import get_npu_core_count + +MAX_FUSED_SIZE = 4096 + + +def fused_linear_jsd_forward( + student_input, + student_weight, + teacher_input, + teacher_weight, + shift_labels, + jsd_beta, + ignore_index, + has_label, + temperature, +): + device = student_input.device + dtype = student_input.dtype + + # inputs have shape: BT x H + # materialized activations will have shape: BT x V + # the increase in memory = BT x V + # reduction can be achieved by partitioning the number of tokens BT into smaller chunks. + # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: + # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor + # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 + BT, H = student_input.shape + V = student_weight.shape[0] + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + inc_factor = triton.cdiv(V, H) # (V + H - 1) // H + chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor + num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size + + grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None + grad_input = torch.zeros_like(student_input) + # we use fp32 for loss accumulator + loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device) + + if has_label: + n_non_ignore = (shift_labels != ignore_index).sum().item() + else: + n_non_ignore = BT + + num_cores = get_npu_core_count() + + for chunk_id in range(num_chunks): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + + # chunk both inputs, shape: chunk_size x H + student_input_chunk = student_input[start_idx:end_idx] + teacher_input_chunk = teacher_input[start_idx:end_idx] + + # shape: chunk_size x V + # For anything starting from logits to the final JSD loss, we do computation + # in FP32 to avoid losing numerical stability. 
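+        # Each chunk materializes only a (chunk_size x V) fp32 logits buffer
+        # instead of the full (BT x V). With the example numbers above
+        # (BT = 4096*4, V = 32000, H = 4096) that is 8 chunks of 2048 rows each.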
+ student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32) + teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32) + chunk_n_rows = student_logits_chunk.shape[0] + + # unreduced loss + loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size + # log-softmax with temperature + student_logits_chunk = student_logits_chunk / temperature + teacher_logits_chunk = teacher_logits_chunk / temperature + student_prob_chunk = torch.log_softmax(student_logits_chunk, dim=-1) + teacher_prob_chunk = torch.log_softmax(teacher_logits_chunk, dim=-1) + + # ensure _input and target are contiguous + student_prob_chunk = student_prob_chunk.contiguous() + teacher_prob_chunk = teacher_prob_chunk.contiguous() + + # Here we calculate the gradient of prob_chunk in place so we can save memory. + # Grid size is capped at NPU core count; the kernel uses a grid-stride loop + # to process multiple rows per program, consistent with the NPU backend pattern. + grid_size = min(num_cores, chunk_n_rows) + _jsd_kernel[(grid_size,)]( + X_ptr=student_prob_chunk, + X_stride=student_prob_chunk.stride(-2), + Y_ptr=teacher_prob_chunk, + Y_stride=teacher_prob_chunk.stride(-2), + loss_ptr=loss_1d_slice, + loss_stride=loss_1d_slice.stride(-2), + dX_ptr=student_prob_chunk, + dX_stride=student_prob_chunk.stride(-2), + label_ptr=( + shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device) + ), # dummy ptr if no label + beta=jsd_beta, + n_non_ignore=n_non_ignore, + ignore_index=ignore_index, + n_rows=chunk_n_rows, + n_cols=V, + BLOCK_SIZE=BLOCK_SIZE, + HAS_LABEL=has_label, + ) + # gradients of prob_chunk in place, shape: chunk_size x V + # gradients of logits_chunk in place, shape: chunk_size x V + student_logits_chunk = ( + student_prob_chunk + - torch.softmax(student_logits_chunk, dim=-1) + * student_prob_chunk.sum(dim=-1, keepdim=True).expand_as(student_prob_chunk) + ) / temperature + # now we traverse back to grad w.r.t. input to `lm_head` and grad + # w.r.t. `lm_head` which should be computed in original dtype + student_logits_chunk = student_logits_chunk.to(dtype) + grad_input[start_idx:end_idx] = student_logits_chunk @ student_weight + + if grad_weight is not None: + grad_weight.add_(student_logits_chunk.t() @ student_input_chunk) + + loss = torch.sum(loss_1d) + return loss, grad_input, grad_weight + + +def fused_linear_jsd_backward(grad_output, grad_input, grad_weight): + # If JSD is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + BT, H = grad_input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + + element_mul_kernel[(n_rows,)]( + grad_input, + grad_input.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # handle grad_weight + if grad_weight is not None: + V, H = grad_weight.shape + n_rows = V + + element_mul_kernel[(n_rows,)]( + grad_weight, + grad_weight.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return grad_input, grad_weight + + +class LigerFusedLinearJSDFunction(torch.autograd.Function): + """ + Fusing the last linear layer with generalized JSD + + Handle the forward and backward pass of the final linear layer via JSD by avoiding + the materialization of the large logits tensor. 
Since JSD is the last layer, we can + compute the gradient at the forward pass. + """ + + @staticmethod + @amp_custom_fwd + def forward( + ctx, + student_input: torch.Tensor, + student_weight: torch.Tensor, + teacher_input: torch.Tensor, + teacher_weight: torch.Tensor, + shift_labels: Optional[torch.Tensor] = None, + jsd_beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + ): + """ + Args: + + student_input (torch.tensor): input of the last projection layer in student model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension. + student_weight (torch.tensor): the last projection layer in student model, with shape (V, H), where V is vocab size + teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension. + teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size + shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1]. + jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5` + ignore_index (int): the index to ignore. Default: -100 + temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0` + + Returns: + loss (torch.Tensor): generalized JSD + """ + has_label = False + if shift_labels is not None: + assert shift_labels.shape == (teacher_input.shape[0],), ( + f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" + ) + shift_labels = shift_labels.contiguous() + has_label = True + + loss, grad_input, grad_weight = fused_linear_jsd_forward( + student_input, + student_weight, + teacher_input, + teacher_weight, + shift_labels, + jsd_beta, + ignore_index, + has_label, + temperature, + ) + # downcast to dtype and store for backward + ctx.save_for_backward( + grad_input.detach(), + grad_weight.detach() if grad_weight is not None else None, + ) + return loss + + @staticmethod + @amp_custom_bwd + def backward(ctx, grad_output): + (grad_input, grad_weight) = ctx.saved_tensors + grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight) + return (grad_input, grad_weight, None, None, None, None, None, None) diff --git a/src/liger_kernel/ops/backends/_ascend/ops/geglu.py b/src/liger_kernel/ops/backends/_ascend/ops/geglu.py new file mode 100755 index 0000000000000000000000000000000000000000..123b2b4f262b4e7744961f4b21f96b3a1379678a --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/geglu.py @@ -0,0 +1,187 @@ +import torch +import triton +import triton.language as tl + +from triton.language.math import tanh + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count + + +@triton.jit +def _geglu_forward_kernel_flat(a_ptr, b_ptr, c_ptr, total_elements, BLOCK_SIZE: tl.constexpr): + """ + High-performance GEGLU forward kernel using flatten 1D approach. + + Uses grid-stride loop pattern for optimal performance on NPU. 
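+
+    Computes, element-wise:
+        c = gelu_tanh(a) * b,
+    where gelu_tanh(a) = 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3)))
+    is the tanh approximation of GELU applied to the gate before the multiply.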
+ """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + # Grid-Stride Loop + start_idx = pid * BLOCK_SIZE + stride = num_progs * BLOCK_SIZE + + # Constants for GELU tanh approximation + sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi) + gelu_coeff = 0.044715 + + for idx in tl.range(start_idx, total_elements, stride): + offsets = idx + tl.arange(0, BLOCK_SIZE) + mask = offsets < total_elements + + a_val = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + b_val = tl.load(b_ptr + offsets, mask=mask, other=0.0) + + # tanh approximation form of GELU is computed with: + # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3))) + a_cubed = a_val * a_val * a_val + tanh_arg = sqrt_2_over_pi * (a_val + gelu_coeff * a_cubed) + tanh_result = tanh(tanh_arg) + geglu_a = 0.5 * a_val * (1.0 + tanh_result) + c_row = geglu_a.cast(b_val.dtype) * b_val + tl.store(c_ptr + offsets, c_row, mask=mask) + + +@triton.jit +def _geglu_backward_kernel_flat(dc_ptr, a_ptr, b_ptr, da_ptr, db_ptr, total_elements, BLOCK_SIZE: tl.constexpr): + """ + High-performance GEGLU backward kernel using flatten 1D approach. + + Uses grid-stride loop pattern for optimal performance on NPU. + """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + start_idx = pid * BLOCK_SIZE + stride = num_progs * BLOCK_SIZE + + # Constants for GELU tanh approximation + sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi) + gelu_coeff = 0.044715 + + for idx in tl.range(start_idx, total_elements, stride): + offsets = idx + tl.arange(0, BLOCK_SIZE) + mask = offsets < total_elements + + dc = tl.load(dc_ptr + offsets, mask=mask, other=0.0) + a = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + b = tl.load(b_ptr + offsets, mask=mask, other=0.0) + + # recomputation to save memory + a_cubed = a * a * a + tanh_arg = sqrt_2_over_pi * (a + gelu_coeff * a_cubed) + tanh_result = tanh(tanh_arg) + geglu_a = 0.5 * a * (1 + tanh_result) + geglu_a = geglu_a.to(dc.dtype).to(tl.float32) + + db = dc.cast(tl.float32) * geglu_a + + # Gradient w.r.t. a can be computed with: + # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2))) + # where z = sqrt(2/pi) * (a + 0.044715 * a^3) + term1 = 0.5 * (1.0 + tanh_result) + tanh_sq = tanh_result * tanh_result + a_sq = a * a + term2 = 0.5 * a * (1.0 - tanh_sq) * (sqrt_2_over_pi * (1.0 + 3.0 * gelu_coeff * a_sq)) + da = dc * b * (term1 + term2) + + tl.store(da_ptr + offsets, da, mask=mask) + tl.store(db_ptr + offsets, db.to(dc.dtype), mask=mask) + + +def get_optimal_block_size(total_elements, is_backward=False): + """ + Calculate optimal Block Size using compute_default_tiling_strategy. + + Args: + total_elements: Total number of elements to process + is_backward: Whether this is for backward pass (requires more memory) + + Returns: + Optimal block size for the kernel + """ + # Memory multiplier based on peak memory usage analysis + if is_backward: + memory_multiplier = 6.0 + else: + memory_multiplier = 3.0 + # Call calculation function + # Treat input as 1D (total_elements,), only tiling on dim 0 + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.9, + dtype_size=4, + memory_multiplier=memory_multiplier, + shapes=((total_elements,),), + tiling_dims=(0,), + ) + + # Parse result + if tile_shapes and len(tile_shapes) > 0: + block_size = tile_shapes[0][0] + return max(256, block_size) + else: + return 2048 + + +def geglu_forward(a, b): + """ + High-performance GEGLU forward pass for NPU using flatten 1D approach. 
+ """ + if not a.is_contiguous(): + a = a.contiguous() + if not b.is_contiguous(): + b = b.contiguous() + + total_elements = a.numel() + c = torch.empty_like(a) + + block_size = get_optimal_block_size(total_elements, is_backward=False) + + num_cores = get_npu_core_count() + grid_size = min(num_cores, (total_elements + block_size - 1) // block_size) + + _geglu_forward_kernel_flat[(grid_size,)](a, b, c, total_elements, BLOCK_SIZE=block_size) + return c + + +def geglu_backward(a, b, dc): + """ + High-performance GEGLU backward pass for NPU using flatten 1D approach. + """ + if not dc.is_contiguous(): + dc = dc.contiguous() + if not a.is_contiguous(): + a = a.contiguous() + if not b.is_contiguous(): + b = b.contiguous() + + total_elements = dc.numel() + grad_a = torch.empty_like(a) + grad_b = torch.empty_like(b) + + block_size = get_optimal_block_size(total_elements, is_backward=True) + + num_cores = get_npu_core_count() + grid_size = min(num_cores, (total_elements + block_size - 1) // block_size) + + _geglu_backward_kernel_flat[(grid_size,)](dc, a, b, grad_a, grad_b, total_elements, BLOCK_SIZE=block_size) + return grad_a, grad_b + + +class LigerGELUMulFunction(torch.autograd.Function): + """High-performance GEGLU function for Ascend NPU.""" + + @staticmethod + @ensure_contiguous + def forward(ctx, a, b): + c = geglu_forward(a, b) + ctx.save_for_backward(a, b) + return c + + @staticmethod + @ensure_contiguous + def backward(ctx, dc): + a, b = ctx.saved_tensors + grad_a, grad_b = geglu_backward(a, b, dc) + return grad_a, grad_b diff --git a/src/liger_kernel/ops/backends/_ascend/ops/group_norm.py b/src/liger_kernel/ops/backends/_ascend/ops/group_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..85d9efc751ebf49ea986895a9800c92ee9b1ee24 --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/group_norm.py @@ -0,0 +1,474 @@ +import torch +import triton +import triton.language as tl + +from triton.language.math import rsqrt + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count + +# ----------------------------------------------------------------------------- +# Kernels (2D row/col tiling + persistent programs) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _group_norm_forward_kernel( + Y_ptr, # pointer to output, shape (B, G, hidden_size) + Y_row_stride, # stride of each batch row in Y + Y_col_stride, # stride of each group row in Y + X_ptr, # pointer to input, shape (B, G, hidden_size) + X_row_stride, # stride of each batch row in X + X_col_stride, # stride of each group row in X + Mean_ptr, # pointer to mean output, shape (B, G) + Mean_row_stride, # stride of each batch row in Mean + Mean_col_stride, # stride of each group row in Mean + RSTD_ptr, # pointer to rstd output, shape (B, G) + RSTD_row_stride, # stride of each batch row in RSTD + RSTD_col_stride, # stride of each group row in RSTD + W_ptr, # pointer to affine scale weights, shape (C) + B_ptr, # pointer to affine bias weights, shape (C) + n_rows, # total logical rows = B * G + hidden_size, + channels_per_group, + num_groups, + SINGLE_CHANNEL_TILE: tl.constexpr, + eps, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + grid_m = tl.cdiv(n_rows, BLOCK_SIZE_M) + num_col_blocks = tl.cdiv(hidden_size, BLOCK_SIZE_N) + 
hidden_size_per_channel = hidden_size // channels_per_group + hidden_size_inv = 1.0 / hidden_size + row_offsets = tl.arange(0, BLOCK_SIZE_M) + col_offsets_base = tl.arange(0, BLOCK_SIZE_N) + + # Persistent-program loop over row tiles. + for block_m in tl.range(pid, grid_m, num_progs): + row_idx = block_m * BLOCK_SIZE_M + row_offsets + row_mask = row_idx < n_rows + batch_idx = row_idx // num_groups + group_idx = row_idx % num_groups + + row_sum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + row_square_sum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + + # Pass 1: accumulate E[x] and E[x^2] for each row tile. + for cb in range(num_col_blocks): + col_offsets = cb * BLOCK_SIZE_N + col_offsets_base + col_mask = col_offsets < hidden_size + mask = row_mask[:, None] & col_mask[None, :] + X_ptrs = ( + X_ptr + batch_idx[:, None] * X_row_stride + group_idx[:, None] * X_col_stride + col_offsets[None, :] + ) + X_block = tl.load(X_ptrs, mask=mask, other=0.0).to(tl.float32) + row_sum += tl.sum(X_block, axis=1) + row_square_sum += tl.sum(X_block * X_block, axis=1) + + mean = row_sum * hidden_size_inv + var = row_square_sum * hidden_size_inv - mean * mean + rstd = rsqrt(tl.maximum(var, 0.0) + eps) + + mean_ptrs = Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride + rstd_ptrs = RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride + tl.store(mean_ptrs, mean, mask=row_mask) + tl.store(rstd_ptrs, rstd, mask=row_mask) + + # Pass 2: normalize + affine transform. + # SINGLE_CHANNEL_TILE indicates the current column tile maps to one channel, + # so W/B can be loaded once per row and broadcast to the tile. + for cb in range(num_col_blocks): + col_offsets = cb * BLOCK_SIZE_N + col_offsets_base + col_mask = col_offsets < hidden_size + mask = row_mask[:, None] & col_mask[None, :] + X_ptrs = ( + X_ptr + batch_idx[:, None] * X_row_stride + group_idx[:, None] * X_col_stride + col_offsets[None, :] + ) + X_block = tl.load(X_ptrs, mask=mask, other=0.0).to(tl.float32) + if SINGLE_CHANNEL_TILE: + local_channel = (cb * BLOCK_SIZE_N) // hidden_size_per_channel + global_channel = group_idx * channels_per_group + local_channel + W_block = tl.load(W_ptr + global_channel, mask=row_mask, other=0.0).to(tl.float32)[:, None] + B_block = tl.load(B_ptr + global_channel, mask=row_mask, other=0.0).to(tl.float32)[:, None] + else: + local_channel = col_offsets // hidden_size_per_channel + global_channel = group_idx[:, None] * channels_per_group + local_channel[None, :] + W_block = tl.load(W_ptr + global_channel, mask=mask, other=0.0).to(tl.float32) + B_block = tl.load(B_ptr + global_channel, mask=mask, other=0.0).to(tl.float32) + Y_block = (X_block - mean[:, None]) * rstd[:, None] * W_block + B_block + Y_ptrs = ( + Y_ptr + batch_idx[:, None] * Y_row_stride + group_idx[:, None] * Y_col_stride + col_offsets[None, :] + ) + tl.store(Y_ptrs, Y_block, mask=mask) + + +@triton.jit +def _group_norm_backward_kernel( + X_ptr, # pointer to input, shape (B, G, hidden_size) + X_row_stride, # stride of each batch row in X + X_col_stride, # stride of each group row in X + W_ptr, # pointer to affine scale weights, shape (C) + Mean_ptr, # pointer to saved group mean, shape (B, G) + Mean_row_stride, # stride of each batch row in Mean + Mean_col_stride, # stride of each group row in Mean + RSTD_ptr, # pointer to saved reciprocal std, shape (B, G) + DX_ptr, # pointer to input gradients, shape (B, G, hidden_size) + DW_scratch_ptr, # pointer to scratch buffer for dW partial sums, shape (grid, C) + DW_scratch_stride, # row stride 
for DW_scratch + DB_scratch_ptr, # pointer to scratch buffer for dB partial sums, shape (grid, C) + DB_scratch_stride, # row stride for DB_scratch + DY_ptr, # pointer to upstream gradients, shape (B, G, hidden_size) + DY_row_stride, # stride of each batch row in DY + DY_col_stride, # stride of each group row in DY + n_rows, # total logical rows = B * G + hidden_size, + channels_per_group, + num_groups, + SINGLE_CHANNEL_TILE: tl.constexpr, + COMPUTE_PARAM_GRAD: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + grid_m = tl.cdiv(n_rows, BLOCK_SIZE_M) + num_col_blocks = tl.cdiv(hidden_size, BLOCK_SIZE_N) + hidden_size_per_channel = hidden_size // channels_per_group + N_inv = 1.0 / hidden_size + row_offsets = tl.arange(0, BLOCK_SIZE_M) + col_offsets_base = tl.arange(0, BLOCK_SIZE_N) + + if COMPUTE_PARAM_GRAD: + DW_scratch_base = DW_scratch_ptr + pid * DW_scratch_stride + DB_scratch_base = DB_scratch_ptr + pid * DB_scratch_stride + + # Persistent-program loop over row tiles. + for block_m in tl.range(pid, grid_m, num_progs): + row_idx = block_m * BLOCK_SIZE_M + row_offsets + row_mask = row_idx < n_rows + batch_idx = row_idx // num_groups + group_idx = row_idx % num_groups + + mean = tl.load( + Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, + mask=row_mask, + other=0.0, + ).to(tl.float32) + rstd = tl.load( + RSTD_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, + mask=row_mask, + other=0.0, + ).to(tl.float32) + + sum_x_hat_wdy = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + sum_wdy = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + + # Pass 1: compute row-wise reduction terms (c1, c2). + for cb in range(num_col_blocks): + col_offsets = cb * BLOCK_SIZE_N + col_offsets_base + col_mask = col_offsets < hidden_size + mask = row_mask[:, None] & col_mask[None, :] + + X_ptrs = ( + X_ptr + batch_idx[:, None] * X_row_stride + group_idx[:, None] * X_col_stride + col_offsets[None, :] + ) + DY_ptrs = ( + DY_ptr + batch_idx[:, None] * DY_row_stride + group_idx[:, None] * DY_col_stride + col_offsets[None, :] + ) + X_block = tl.load(X_ptrs, mask=mask, other=0.0).to(tl.float32) + DY_block = tl.load(DY_ptrs, mask=mask, other=0.0).to(tl.float32) + + if SINGLE_CHANNEL_TILE: + local_channel = (cb * BLOCK_SIZE_N) // hidden_size_per_channel + global_channel = group_idx * channels_per_group + local_channel + W_block = tl.load(W_ptr + global_channel, mask=row_mask, other=0.0).to(tl.float32)[:, None] + else: + local_channel = col_offsets // hidden_size_per_channel + global_channel = group_idx[:, None] * channels_per_group + local_channel[None, :] + W_block = tl.load(W_ptr + global_channel, mask=mask, other=0.0).to(tl.float32) + + x_hat = (X_block - mean[:, None]) * rstd[:, None] + wdy = W_block * DY_block + sum_x_hat_wdy += tl.sum(tl.where(mask, x_hat * wdy, 0.0), axis=1) + sum_wdy += tl.sum(tl.where(mask, wdy, 0.0), axis=1) + + c1 = sum_x_hat_wdy * N_inv + c2 = sum_wdy * N_inv + + # Pass 2: compute DX and optionally accumulate DW/DB. + # COMPUTE_PARAM_GRAD=False is used to skip expensive atomics in cases + # where host-side dense reduction is faster/more stable. 
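+        # Per-element group-norm input gradient:
+        #   dx = (wdy - (x_hat * c1 + c2)) * rstd,
+        # with c1 = mean(x_hat * wdy) and c2 = mean(wdy) over the group; this is
+        # the standard normalization backward with W already folded into dy.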
+ for cb in range(num_col_blocks): + col_offsets = cb * BLOCK_SIZE_N + col_offsets_base + col_mask = col_offsets < hidden_size + mask = row_mask[:, None] & col_mask[None, :] + + X_ptrs = ( + X_ptr + batch_idx[:, None] * X_row_stride + group_idx[:, None] * X_col_stride + col_offsets[None, :] + ) + DY_ptrs = ( + DY_ptr + batch_idx[:, None] * DY_row_stride + group_idx[:, None] * DY_col_stride + col_offsets[None, :] + ) + X_block = tl.load(X_ptrs, mask=mask, other=0.0).to(tl.float32) + DY_block = tl.load(DY_ptrs, mask=mask, other=0.0).to(tl.float32) + + if SINGLE_CHANNEL_TILE: + local_channel = (cb * BLOCK_SIZE_N) // hidden_size_per_channel + global_channel = group_idx * channels_per_group + local_channel + W_block = tl.load(W_ptr + global_channel, mask=row_mask, other=0.0).to(tl.float32)[:, None] + else: + local_channel = col_offsets // hidden_size_per_channel + global_channel = group_idx[:, None] * channels_per_group + local_channel[None, :] + W_block = tl.load(W_ptr + global_channel, mask=mask, other=0.0).to(tl.float32) + + x_hat = (X_block - mean[:, None]) * rstd[:, None] + wdy = W_block * DY_block + DX_block = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd[:, None] + + DX_ptrs = ( + DX_ptr + batch_idx[:, None] * X_row_stride + group_idx[:, None] * X_col_stride + col_offsets[None, :] + ) + tl.store(DX_ptrs, DX_block.to(X_ptr.dtype.element_ty), mask=mask) + + if COMPUTE_PARAM_GRAD: + if SINGLE_CHANNEL_TILE: + dW_partial = tl.sum(tl.where(mask, DY_block * x_hat, 0.0), axis=1) + dB_partial = tl.sum(tl.where(mask, DY_block, 0.0), axis=1) + tl.atomic_add(DW_scratch_base + global_channel, dW_partial, mask=row_mask) + tl.atomic_add(DB_scratch_base + global_channel, dB_partial, mask=row_mask) + + +# ----------------------------------------------------------------------------- +# Helper: call compute_default_tiling_strategy +# ----------------------------------------------------------------------------- + + +def get_optimal_block_size(n_rows, dtype_size, BLOCK_SIZE_N, is_backward: bool = False): + # Backward keeps larger live-state than forward in this kernel. + multiplier = 7.0 if is_backward else 6.0 + + # Use fp32-size as conservative UB estimate for tiling. + dtype_size = max(dtype_size, 4) + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.9, + dtype_size=dtype_size, + memory_multiplier=multiplier, + shapes=((n_rows, BLOCK_SIZE_N),), + tiling_dims=(0,), + ) + if tile_shapes and len(tile_shapes) > 0: + return tile_shapes[0][0] + return triton.next_power_of_2(min(128, n_rows)) + + +def group_norm_forward(X, num_channels, num_groups, W, B, eps): + shape = X.shape + batch_size = shape[0] + channels_per_group = num_channels // num_groups + # Reshape X so that the mean / std are computed across each group + X = X.view(batch_size, num_groups, -1).contiguous() + + hidden_size = X.shape[-1] + hidden_size_per_channel = hidden_size // channels_per_group + n_rows = batch_size * num_groups + + BLOCK_SIZE_N = min(128, triton.next_power_of_2(hidden_size)) + BLOCK_SIZE_M = get_optimal_block_size(n_rows, X.element_size(), BLOCK_SIZE_N) + + # Fast path condition: each column tile must lie entirely inside one channel + # segment of length `hidden_size_per_channel`. + # + # Layout of a row: + # | channel0 | channel1 | channel2 | ... + # |----Hc----|----Hc----| + # Hc = hidden_size_per_channel + # + # The kernel processes tiles of shape (BLOCK_SIZE_M, BLOCK_SIZE_N). + # Channel boundaries exist only along the column dimension, because + # each row corresponds to a different (batch, group). 
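+    # Channel identity is a pure function of the column offset within a row
+    # (column c maps to channel c // Hc; e.g. with Hc = 192, column 200 lies in
+    # channel 1), never of the row index.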
+ # + # Therefore only BLOCK_SIZE_N matters for whether a tile crosses + # channel boundaries; BLOCK_SIZE_M does not affect channel mapping. + # + # If BLOCK_SIZE_N divides Hc and is <= Hc, each column tile belongs + # to exactly one channel. In that case W/B can be loaded once and + # broadcast across the tile (fast path). + # + # Otherwise a tile may span multiple channels, requiring per-element + # channel index computation and parameter loads (slow path). + single_channel_tile = BLOCK_SIZE_N <= hidden_size_per_channel and hidden_size_per_channel % BLOCK_SIZE_N == 0 + + num_cores = get_npu_core_count() + grid = min(num_cores, triton.cdiv(n_rows, BLOCK_SIZE_M)) + + Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device) + Mean = torch.empty((batch_size, num_groups), dtype=X.dtype, device=X.device) + RSTD = torch.empty((batch_size, num_groups), dtype=X.dtype, device=X.device) + + _group_norm_forward_kernel[(grid,)]( + Y, + Y.stride(0), + Y.stride(1), + X, + X.stride(0), + X.stride(1), + Mean, + Mean.stride(0), + Mean.stride(1), + RSTD, + RSTD.stride(0), + RSTD.stride(1), + W, + B, + n_rows, + hidden_size, + channels_per_group, + num_groups, + SINGLE_CHANNEL_TILE=single_channel_tile, + eps=eps, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return Y.view(*shape), X.view(*shape), Mean, RSTD + + +def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): + shape = dY.shape + batch_size = shape[0] + channels_per_group = num_channels // num_groups + X_grouped = X.view(batch_size, num_groups, -1) + dY_grouped = dY.view(batch_size, num_groups, -1) + hidden_size = dY_grouped.shape[-1] + hidden_size_per_channel = hidden_size // channels_per_group + n_rows = batch_size * num_groups + + BLOCK_SIZE_N = min(128, triton.next_power_of_2(hidden_size)) + BLOCK_SIZE_M = get_optimal_block_size( + n_rows, + X.element_size(), + BLOCK_SIZE_N, + is_backward=True, + ) + + # Same condition as forward: + # if true, each BLOCK_SIZE_N tile maps cleanly to one channel segment. + single_channel_tile = BLOCK_SIZE_N <= hidden_size_per_channel and hidden_size_per_channel % BLOCK_SIZE_N == 0 + + num_cores = get_npu_core_count() + grid = min(num_cores, triton.cdiv(n_rows, BLOCK_SIZE_M)) + # For non-single-channel tiles, per-element atomic updates are costly. + # In that case, kernel computes DX only and DW/DB are reduced on host side. + compute_param_grad = single_channel_tile + + DX = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device) + if compute_param_grad: + DW_scratch = torch.zeros((grid, num_channels), dtype=torch.float32, device=W.device) + DB_scratch = torch.zeros((grid, num_channels), dtype=torch.float32, device=W.device) + else: + # Not used when COMPUTE_PARAM_GRAD=False. + # Intentionally set to None to enforce fail-fast behavior if accidentally accessed. 
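+        # (The dense DW/DB reduction for this case happens on the host after the
+        # kernel returns; see the fallback branch further down.)
+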
+ DW_scratch = None + DB_scratch = None + + _group_norm_backward_kernel[(grid,)]( + X_grouped, + X_grouped.stride(0), + X_grouped.stride(1), + W, + Mean, + Mean.stride(0), + Mean.stride(1), + RSTD, + DX, + DW_scratch, + 0 if not compute_param_grad else DW_scratch.stride(0), + DB_scratch, + 0 if not compute_param_grad else DB_scratch.stride(0), + dY_grouped, + dY_grouped.stride(0), + dY_grouped.stride(1), + n_rows, + hidden_size, + channels_per_group, + num_groups, + SINGLE_CHANNEL_TILE=single_channel_tile, + COMPUTE_PARAM_GRAD=compute_param_grad, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + + # Precision note: + # - In-kernel atomic_add on floating-point values is order-dependent under parallel + # scheduling (non-associative summation), which can introduce run-to-run numerical + # differences in DW/DB for contention-heavy shapes. + # - Host-side dense reduction provides a more stable accumulation pattern for these + # difficult layouts. + if compute_param_grad: + DW = DW_scratch.sum(dim=0).to(W.dtype) + DB = DB_scratch.sum(dim=0).to(W.dtype) + else: + # Fallback path to avoid severe atomic contention when SINGLE_CHANNEL_TILE=False. + # Layout: [B, G, hidden_size] -> [B, G, C_per_G, hidden_per_channel] + X4 = X_grouped.reshape(batch_size, num_groups, channels_per_group, hidden_size_per_channel).to(torch.float32) + dY4 = dY_grouped.reshape(batch_size, num_groups, channels_per_group, hidden_size_per_channel).to(torch.float32) + mean4 = Mean.reshape(batch_size, num_groups, 1, 1).to(torch.float32) + rstd4 = RSTD.reshape(batch_size, num_groups, 1, 1).to(torch.float32) + + x_hat4 = (X4 - mean4) * rstd4 + DW = (dY4 * x_hat4).sum(dim=(0, 3)).reshape(-1).to(W.dtype) + DB = dY4.sum(dim=(0, 3)).reshape(-1).to(W.dtype) + + return DX.view(*shape), DW, DB + + +class LigerGroupNormFunction(torch.autograd.Function): + """ + Group Normalization autograd function for Ascend NPU. + + Forward computes, for each sample/group: + y = (x - mean) * rstd * weight + bias + where: + mean = E[x], rstd = 1 / sqrt(Var[x] + eps) + + The kernel uses row/column tiling with persistent programs. Backward computes + input gradients in Triton and computes parameter gradients either via Triton + atomics (fast path) or host-side dense reduction (fallback path), depending + on the tile/channel layout. 
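+
+    Shape note (as handled by the host wrappers above): X arrives as
+    (batch, num_channels, *), is viewed as (batch, num_groups, hidden) for the
+    kernels, and Mean/RSTD are stored per (batch, group) pair.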
+ """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, + X, + affine_scaling_weight, + affine_shifting_bias, + num_channels, + num_groups, + eps, + ): + Y, X, Mean, RSTD = group_norm_forward( + X, + num_channels, + num_groups, + affine_scaling_weight, + affine_shifting_bias, + eps, + ) + ctx.num_channels = num_channels + ctx.num_groups = num_groups + ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD) + return Y + + @staticmethod + @ensure_contiguous + def backward(ctx, dY): + X, W, B, Mean, RSTD = ctx.saved_tensors + DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups) + return DX, DW, DB, None, None, None diff --git a/src/liger_kernel/ops/backends/_ascend/ops/grpo_loss.py b/src/liger_kernel/ops/backends/_ascend/ops/grpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..b430f4c3b37be888bf1a56336e4c645ff1293413 --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/grpo_loss.py @@ -0,0 +1,1006 @@ +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count + +# Loss type mapping for Triton constexpr branching +# GRPO/DAPO/BNPO/DR_GRPO share identical per-token loss computation (standard PPO clipping) +_TYPE_GRPO: tl.constexpr = tl.constexpr(0) +_TYPE_CISPO: tl.constexpr = tl.constexpr(1) +_TYPE_SAPO: tl.constexpr = tl.constexpr(2) + +_str_to_loss_type = { + "grpo": _TYPE_GRPO.value, + "dapo": _TYPE_GRPO.value, + "bnpo": _TYPE_GRPO.value, + "dr_grpo": _TYPE_GRPO.value, + "luspo": _TYPE_GRPO.value, + "cispo": _TYPE_CISPO.value, + "sapo": _TYPE_SAPO.value, +} + + +def calculate_tile_count_2d(batch_size, seq_len, num_cores): + """Compute optimal grid configuration for parallel processing.""" + grid_batch = batch_size + cores_per_sample = min(seq_len, num_cores // batch_size) + cores_per_sample = max(1, cores_per_sample) + grid_seq = cores_per_sample + total = grid_batch * grid_seq + if total > num_cores: + grid_seq = max(1, num_cores // grid_batch) + return (grid_batch, grid_seq) + + +def compute_block_size_softmax(seq_vocab_size): + """Determine optimal block size for selective log-softmax kernel.""" + multiplier = 6.0 + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((seq_vocab_size,),), tiling_dims=(0,) + ) + if tile_shapes and len(tile_shapes) > 0: + return tile_shapes[0][0] + return 2048 + + +def compute_block_size_forward(seq_vocab_size): + """Determine optimal block size for forward pass kernel.""" + multiplier = 10.0 + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((seq_vocab_size,),), tiling_dims=(0,) + ) + if tile_shapes and len(tile_shapes) > 0: + return tile_shapes[0][0] + return 2048 + + +def compute_block_size_backward(seq_vocab_size): + """Determine optimal block size for backward pass kernel.""" + multiplier = 12.0 + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((seq_vocab_size,),), tiling_dims=(0,) + ) + if tile_shapes and len(tile_shapes) > 0: + return tile_shapes[0][0] + return 2048 + + +@triton.jit +def _selective_log_softmax_kernel( + LOGITS, + INPUT_IDS, + LOG_P, + MASK, + TEMPERATURE, + stride_input_ids_b, + L: tl.constexpr, + N: tl.constexpr, + 
BLOCK_N: tl.constexpr = 2048, +): + pid_b = tl.program_id(0) + pid_l = tl.program_id(1) + num_progs_l = tl.num_programs(1) + + batch_start = pid_b * L + batch_end = batch_start + L + start_token = batch_start + pid_l + stride = num_progs_l + + for token_idx in tl.range(start_token, batch_end, stride): + off_b = token_idx // L + off_l = token_idx % L + + should_process = 1 + if MASK is not None: + MASK_local = MASK + off_b * stride_input_ids_b + off_l + not_skip = tl.load(MASK_local) + should_process = not_skip + + if should_process != 0: + LOGITS_local = LOGITS + off_b * (L + 1) * N + off_l * N + INPUT_IDS_local = INPUT_IDS + off_b * stride_input_ids_b + off_l + LOG_P_local = LOG_P + token_idx + + m_i = float("-inf") + l_i = 0.0 + for start in range(0, N, BLOCK_N): + cols = start + tl.arange(0, BLOCK_N) + logits = tl.load(LOGITS_local + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE + new_m_i = tl.maximum(m_i, tl.max(logits)) + alpha = tl.exp(m_i - new_m_i) + l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i)) + m_i = new_m_i + lse = m_i + tl.log(l_i) + + ids = tl.load(INPUT_IDS_local) + x = tl.load(LOGITS_local + ids).to(tl.float32) / TEMPERATURE + logp = x - lse + tl.store(LOG_P_local, logp) + + +@triton.jit +def _grpo_loss_fwd_kernel( + LOGITS, + OLD_LOGP, + REF_LOGP, + INPUT_IDS, + COMPLETION_MASK, + ADVANTAGES, + VLLM_IS_RATIO, + VLLM_IS_RATIO_STRIDE, + LOSS, + LSE, + KL, + IS_CLIPPED, + TEMPERATURE, + BETA: tl.constexpr, + EPS_LOW, + EPS_HIGH, + LOSS_TYPE: tl.constexpr, + SAPO_TEMP_POS, + SAPO_TEMP_NEG, + DELTA, + USE_BIAS_CORRECTION_KL: tl.constexpr, + L: tl.constexpr, + N: tl.constexpr, + BLOCK_N: tl.constexpr = 2048, +): + pid_b = tl.program_id(0) + pid_l = tl.program_id(1) + num_progs_l = tl.num_programs(1) + + batch_start = pid_b * L + batch_end = batch_start + L + start_token = batch_start + pid_l + stride = num_progs_l + + for token_idx in tl.range(start_token, batch_end, stride): + off_b = token_idx // L + off_l = token_idx % L + + should_process = 1 + if COMPLETION_MASK is not None: + COMPLETION_MASK_local = COMPLETION_MASK + off_b * L + off_l + not_skip = tl.load(COMPLETION_MASK_local) + should_process = not_skip + + if should_process != 0: + LOGITS_local = LOGITS + off_b * (L + 1) * N + off_l * N + INPUT_IDS_local = INPUT_IDS + off_b * L + off_l + ADVANTAGES_local = ADVANTAGES + off_b + LOSS_local = LOSS + token_idx + LSE_local = LSE + token_idx + IS_CLIPPED_local = IS_CLIPPED + token_idx + + m_i = float("-inf") + l_i = 0.0 + for start in range(0, N, BLOCK_N): + cols = start + tl.arange(0, BLOCK_N) + logits = tl.load(LOGITS_local + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE + new_m_i = tl.maximum(m_i, tl.max(logits)) + alpha = tl.exp(m_i - new_m_i) + l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i)) + m_i = new_m_i + lse = m_i + tl.log(l_i) + + idx = tl.load(INPUT_IDS_local) + x = tl.load(LOGITS_local + idx).to(tl.float32) / TEMPERATURE + logp = x - lse + if OLD_LOGP is None: + old_logp = logp + else: + OLD_LOGP_local = OLD_LOGP + token_idx + old_logp = tl.load(OLD_LOGP_local).to(tl.float32) + coef_1 = tl.exp(logp - old_logp) + advantage = tl.load(ADVANTAGES_local).to(tl.float32) + + if LOSS_TYPE == 0: # GRPO/DAPO/BNPO/DR_GRPO + coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH) + is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0) + is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0) + is_clipped = is_low_clipped | is_high_clipped + if DELTA != 0.0: + coef_1 = tl.minimum(coef_1, DELTA) + 
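+                # PPO-style clipped surrogate: with ratio r = exp(logp - old_logp)
+                # and advantage A, the loss is -min(r * A, clip(r, 1 - EPS_LOW, 1 + EPS_HIGH) * A).
+                # For example, with A > 0, r = 1.5, EPS_HIGH = 0.2, the clipped term
+                # 1.2 * A is smaller, so the loss is -1.2 * A and its gradient
+                # w.r.t. r vanishes, which is exactly the high-clipped regime
+                # flagged above.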
per_token_loss1 = coef_1 * advantage + per_token_loss2 = coef_2 * advantage + per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2) + + elif LOSS_TYPE == 1: # CISPO + coef_2 = tl.minimum(coef_1, EPS_HIGH) + per_token_loss = -coef_2 * advantage * logp + is_clipped = (coef_1 > EPS_HIGH) & (advantage > 0) + + elif LOSS_TYPE == 2: # SAPO + temperature = tl.where(advantage > 0, SAPO_TEMP_POS, SAPO_TEMP_NEG) + sigmoid_input = temperature * (coef_1 - 1.0) + sapo_coef = tl.sigmoid(sigmoid_input) * 4.0 / temperature + per_token_loss = -sapo_coef * advantage + is_clipped = 0.0 + + if VLLM_IS_RATIO is not None: + vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to( + tl.float32 + ) + per_token_loss = per_token_loss * vllm_is_ratio + + if BETA != 0.0: + REF_LOGP_local = REF_LOGP + token_idx + KL_local = KL + token_idx + ref_logp = tl.load(REF_LOGP_local).to(tl.float32) + kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1 + if USE_BIAS_CORRECTION_KL: + kl = kl * tl.exp(logp - old_logp) + per_token_loss += BETA * kl + tl.store(KL_local, kl) + + tl.store(LOSS_local, per_token_loss) + tl.store(LSE_local, lse) + tl.store(IS_CLIPPED_local, is_clipped) + + +@triton.jit +def _grpo_loss_fwd_kernel_seq( + LOGITS, + OLD_LOGP, + REF_LOGP, + INPUT_IDS, + COMPLETION_MASK, + ADVANTAGES, + COEF_1, + COEF_2, + IS_CLIPPED_SEQ, + VLLM_IS_RATIO, + VLLM_IS_RATIO_STRIDE, + LOSS, + LSE, + KL, + IS_CLIPPED, + TEMPERATURE, + BETA: tl.constexpr, + USE_BIAS_CORRECTION_KL: tl.constexpr, + L: tl.constexpr, + N: tl.constexpr, + BLOCK_N: tl.constexpr = 2048, +): + pid_b = tl.program_id(0) + pid_l = tl.program_id(1) + num_progs_l = tl.num_programs(1) + + batch_start = pid_b * L + batch_end = batch_start + L + start_token = batch_start + pid_l + stride = num_progs_l + + for token_idx in tl.range(start_token, batch_end, stride): + off_b = token_idx // L + off_l = token_idx % L + + should_process = 1 + if COMPLETION_MASK is not None: + COMPLETION_MASK_local = COMPLETION_MASK + off_b * L + off_l + not_skip = tl.load(COMPLETION_MASK_local) + should_process = not_skip + + if should_process != 0: + LOGITS_local = LOGITS + off_b * (L + 1) * N + off_l * N + INPUT_IDS_local = INPUT_IDS + off_b * L + off_l + ADVANTAGES_local = ADVANTAGES + off_b + COEF_1_local = COEF_1 + off_b + COEF_2_local = COEF_2 + off_b + IS_CLIPPED_SEQ_local = IS_CLIPPED_SEQ + off_b + LOSS_local = LOSS + token_idx + LSE_local = LSE + token_idx + IS_CLIPPED_local = IS_CLIPPED + token_idx + + m_i = float("-inf") + l_i = 0.0 + for start in range(0, N, BLOCK_N): + cols = start + tl.arange(0, BLOCK_N) + logits = tl.load(LOGITS_local + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE + new_m_i = tl.maximum(m_i, tl.max(logits)) + alpha = tl.exp(m_i - new_m_i) + l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i)) + m_i = new_m_i + lse = m_i + tl.log(l_i) + + idx = tl.load(INPUT_IDS_local) + x = tl.load(LOGITS_local + idx).to(tl.float32) / TEMPERATURE + logp = x - lse + + coef_1 = tl.load(COEF_1_local).to(tl.float32) + coef_2 = tl.load(COEF_2_local).to(tl.float32) + is_clipped_seq = tl.load(IS_CLIPPED_SEQ_local) + + advantage = tl.load(ADVANTAGES_local).to(tl.float32) + per_token_loss1 = coef_1 * advantage + per_token_loss2 = coef_2 * advantage + per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2) + + if VLLM_IS_RATIO is not None: + vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to( + tl.float32 + ) + per_token_loss = 
per_token_loss * vllm_is_ratio + + if BETA != 0.0: + REF_LOGP_local = REF_LOGP + token_idx + KL_local = KL + token_idx + ref_logp = tl.load(REF_LOGP_local).to(tl.float32) + kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1 + if USE_BIAS_CORRECTION_KL: + if OLD_LOGP is None: + old_logp = logp + else: + old_logp = tl.load(OLD_LOGP + token_idx).to(tl.float32) + kl = kl * tl.exp(logp - old_logp) + per_token_loss += BETA * kl + tl.store(KL_local, kl) + + tl.store(LOSS_local, per_token_loss) + tl.store(LSE_local, lse) + tl.store(IS_CLIPPED_local, is_clipped_seq) + + +@triton.jit +def _grpo_loss_bwd_kernel_seq( + DLOSS, + DLOSS_SUM, + DLOGITS, + LOGITS, + OLD_LOGP, + REF_LOGP, + INPUT_IDS, + ADVANTAGES, + COMPLETION_MASK, + LSE, + COEF_1, + SEQ_LEN, + TEMPERATURE, + BETA: tl.constexpr, + USE_BIAS_CORRECTION_KL: tl.constexpr, + EPS_LOW, + EPS_HIGH, + DELTA, + loss_stride0, + loss_stride1, + L: tl.constexpr, + N: tl.constexpr, + BLOCK_N: tl.constexpr = 2048, +): + pid_b = tl.program_id(0) + pid_l = tl.program_id(1) + num_progs_l = tl.num_programs(1) + + batch_start = pid_b * L + batch_end = batch_start + L + start_token = batch_start + pid_l + stride = num_progs_l + + for token_idx in tl.range(start_token, batch_end, stride): + off_b = token_idx // L + off_l = token_idx % L + + DLOGITS_local = DLOGITS + off_b * (L + 1) * N + off_l * N + + should_process = 1 + if COMPLETION_MASK is not None: + COMPLETION_MASK_local = COMPLETION_MASK + off_b * L + off_l + not_skip = tl.load(COMPLETION_MASK_local) + should_process = not_skip + + if should_process == 0: + for start in range(0, N, BLOCK_N): + cols = tl.arange(0, BLOCK_N) + start + tl.store(DLOGITS_local + cols, 0.0, mask=cols < N) + else: + LOGITS_local = LOGITS + off_b * (L + 1) * N + off_l * N + DLOSS_local = DLOSS + off_b * loss_stride0 + off_l * loss_stride1 + DLOSS_SUM_local = DLOSS_SUM + off_b + INPUT_IDS_local = INPUT_IDS + off_b * L + off_l + ADVANTAGES_local = ADVANTAGES + off_b + LSE_local = LSE + token_idx + COEF_1_local = COEF_1 + off_b + SEQ_LEN_local = SEQ_LEN + off_b + + dloss = tl.load(DLOSS_local).to(tl.float32) + dloss_sum = tl.load(DLOSS_SUM_local).to(tl.float32) + lse = tl.load(LSE_local).to(tl.float32) + coef_1 = tl.load(COEF_1_local).to(tl.float32) + seq_len = tl.load(SEQ_LEN_local).to(tl.float32) + + idx = tl.load(INPUT_IDS_local) + x = tl.load(LOGITS_local + idx).to(tl.float32) / TEMPERATURE + logp = x - lse + + advantage = tl.load(ADVANTAGES_local).to(tl.float32) + coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH) + if DELTA != 0.0: + coef_1_for_loss = tl.minimum(coef_1, DELTA) + else: + coef_1_for_loss = coef_1 + per_token_loss1 = coef_1_for_loss * advantage + per_token_loss2 = coef_2 * advantage + is_unclipped = per_token_loss2 >= per_token_loss1 + + dlogp = -coef_1 * advantage / seq_len * is_unclipped * dloss_sum + if DELTA != 0.0: + dlogp = dlogp * (coef_1 <= DELTA) + + if BETA != 0.0: + REF_LOGP_local = REF_LOGP + token_idx + ref_logp = tl.load(REF_LOGP_local).to(tl.float32) + if USE_BIAS_CORRECTION_KL: + if OLD_LOGP is None: + old_logp = logp + else: + old_logp = tl.load(OLD_LOGP + token_idx).to(tl.float32) + token_coef_1 = tl.exp(logp - old_logp) + dlogp += BETA * token_coef_1 * (logp - ref_logp) * dloss + else: + dlogp += BETA * (1 - tl.exp(ref_logp - logp)) * dloss + + dlogp = dlogp / TEMPERATURE + for start_n in tl.range(0, N, BLOCK_N): + cols = start_n + tl.arange(0, BLOCK_N) + logits = tl.load(LOGITS_local + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE + probs = tl.exp(logits - lse) + 
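+                # Log-softmax backward: d logp(idx) / d logit_j = 1[j == idx] - softmax_j,
+                # so each column receives (indicator - prob) scaled by the upstream dlogp.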
dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp + tl.store(DLOGITS_local + cols, dlogits, mask=cols < N) + + +@triton.jit +def _grpo_loss_bwd_kernel( + DLOSS, + DLOGITS, + LOGITS, + OLD_LOGP, + REF_LOGP, + INPUT_IDS, + ADVANTAGES, + COMPLETION_MASK, + LSE, + VLLM_IS_RATIO, + VLLM_IS_RATIO_STRIDE, + TEMPERATURE, + BETA: tl.constexpr, + EPS_LOW, + EPS_HIGH, + LOSS_TYPE: tl.constexpr, + SAPO_TEMP_POS, + SAPO_TEMP_NEG, + DELTA, + USE_BIAS_CORRECTION_KL: tl.constexpr, + loss_stride0, + loss_stride1, + L: tl.constexpr, + N: tl.constexpr, + BLOCK_N: tl.constexpr = 2048, +): + pid_b = tl.program_id(0) + pid_l = tl.program_id(1) + num_progs_l = tl.num_programs(1) + + batch_start = pid_b * L + batch_end = batch_start + L + start_token = batch_start + pid_l + stride = num_progs_l + + for token_idx in tl.range(start_token, batch_end, stride): + off_b = token_idx // L + off_l = token_idx % L + + DLOGITS_local = DLOGITS + off_b * (L + 1) * N + off_l * N + + should_process = 1 + if COMPLETION_MASK is not None: + COMPLETION_MASK_local = COMPLETION_MASK + off_b * L + off_l + not_skip = tl.load(COMPLETION_MASK_local) + should_process = not_skip + + if should_process == 0: + for start in range(0, N, BLOCK_N): + cols = tl.arange(0, BLOCK_N) + start + tl.store(DLOGITS_local + cols, 0.0, mask=cols < N) + else: + LOGITS_local = LOGITS + off_b * (L + 1) * N + off_l * N + DLOSS_local = DLOSS + off_b * loss_stride0 + off_l * loss_stride1 + INPUT_IDS_local = INPUT_IDS + off_b * L + off_l + ADVANTAGES_local = ADVANTAGES + off_b + LSE_local = LSE + token_idx + + dloss = tl.load(DLOSS_local).to(tl.float32) + lse = tl.load(LSE_local).to(tl.float32) + + idx = tl.load(INPUT_IDS_local) + x = tl.load(LOGITS_local + idx).to(tl.float32) / TEMPERATURE + logp = x - lse + if OLD_LOGP is None: + old_logp = logp + else: + OLD_LOGP_local = OLD_LOGP + token_idx + old_logp = tl.load(OLD_LOGP_local).to(tl.float32) + coef_1 = tl.exp(logp - old_logp) + advantage = tl.load(ADVANTAGES_local).to(tl.float32) + + if LOSS_TYPE == 0: # GRPO/DAPO/BNPO/DR_GRPO + coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH) + if DELTA != 0.0: + coef_1_for_loss = tl.minimum(coef_1, DELTA) + else: + coef_1_for_loss = coef_1 + per_token_loss1 = coef_1_for_loss * advantage + per_token_loss2 = coef_2 * advantage + mask = per_token_loss2 >= per_token_loss1 + dlogp = -coef_1 * advantage * mask + if DELTA != 0.0: + dlogp = dlogp * (coef_1 <= DELTA) + + elif LOSS_TYPE == 1: # CISPO + coef_2 = tl.minimum(coef_1, EPS_HIGH) + dlogp = -coef_2 * advantage + + elif LOSS_TYPE == 2: # SAPO + temperature = tl.where(advantage > 0, SAPO_TEMP_POS, SAPO_TEMP_NEG) + sigmoid_input = temperature * (coef_1 - 1.0) + sigmoid_val = tl.sigmoid(sigmoid_input) + d_sapo_d_coef1 = 4.0 * sigmoid_val * (1.0 - sigmoid_val) + dlogp = -advantage * d_sapo_d_coef1 * coef_1 + + if VLLM_IS_RATIO is not None: + vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to( + tl.float32 + ) + dlogp = dlogp * vllm_is_ratio + + if BETA != 0.0: + REF_LOGP_local = REF_LOGP + token_idx + ref_logp = tl.load(REF_LOGP_local).to(tl.float32) + if USE_BIAS_CORRECTION_KL: + dlogp += BETA * coef_1 * (logp - ref_logp) + else: + dlogp += BETA * (1 - tl.exp(ref_logp - logp)) + + dlogp = dlogp * dloss / TEMPERATURE + tl.debug_barrier() + for start_n in tl.range(0, N, BLOCK_N): + cols = start_n + tl.arange(0, BLOCK_N) + logits = tl.load(LOGITS_local + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE + probs = tl.exp(logits - lse) + dlogits = 
tl.where(cols == idx, 1 - probs, -probs) * dlogp + tl.store(DLOGITS_local + cols, dlogits, mask=cols < N) + + +@torch.no_grad +def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None): + """Compute log probabilities for specific token IDs with selective masking.""" + assert logits.is_contiguous() + B, L_ADD_1, N = logits.shape + L = L_ADD_1 - 1 + input_ids = input_ids[:, -L:] + if mask is not None: + mask = mask[:, -L:] + log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device) + + block_n = compute_block_size_softmax(N) + num_cores = get_npu_core_count() + grid = calculate_tile_count_2d(B, L, num_cores) + _selective_log_softmax_kernel[grid]( + logits, + input_ids, + log_p, + mask, + temperature, + input_ids.stride(0), + L, + N, + BLOCK_N=block_n, + ) + return log_p + + +def compute_distribution_normalizer(completion_mask): + """Calculate global active token count for distributed loss normalization.""" + normalizer = completion_mask.to(torch.float32).sum() + world_size = 1 + if torch.distributed.is_available() and torch.distributed.is_initialized(): + normalizer = normalizer.clone() + torch.distributed.all_reduce(normalizer, op=torch.distributed.ReduceOp.SUM) + world_size = torch.distributed.get_world_size() + normalizer = normalizer / world_size + return torch.clamp(normalizer, min=1.0) + + +def reduce_loss(per_token_loss, mask, loss_type, max_completion_length, batch_size, seq_len): + """Apply reduction strategy based on specified loss type.""" + if loss_type == "grpo" or loss_type == "sapo": + return ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() + elif loss_type == "bnpo": + return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0) + elif loss_type == "dr_grpo": + max_len = max_completion_length if max_completion_length is not None else seq_len + return (per_token_loss * mask).sum() / (batch_size * max_len) + elif loss_type == "dapo" or loss_type == "cispo": + return (per_token_loss * mask).sum() / compute_distribution_normalizer(mask) + elif loss_type == "luspo": + return (per_token_loss * mask.sum(-1, keepdim=True)).mean() + raise ValueError(f"Unknown loss_type: {loss_type}. Expected one of: grpo, bnpo, dr_grpo, dapo, cispo, sapo, luspo") + + +def grpo_loss_forward_triton( + ctx, + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + inplace, + loss_type="grpo", + max_completion_length=None, + reduce=True, + importance_sampling_level="token", + sapo_temperature_pos=1.0, + sapo_temperature_neg=1.05, + vllm_is_ratio=None, + delta=None, + use_bias_correction_kl=False, +): + """Forward pass computation for GRPO loss.""" + assert logits.is_contiguous() and completion_ids.is_contiguous() + assert old_logp is None or old_logp.is_contiguous() + assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True + assert importance_sampling_level in ("token", "sequence"), ( + f"importance_sampling_level must be 'token' or 'sequence', got {importance_sampling_level}" + ) + + if loss_type not in _str_to_loss_type: + raise ValueError(f"Unknown loss_type '{loss_type}'. 
Supported types: {list(_str_to_loss_type.keys())}") + + if delta is not None and loss_type in ("cispo", "sapo"): + raise ValueError(f"delta (two-sided clipping) is not supported for loss_type='{loss_type}'.") + + delta_val = 0.0 if delta is None else float(delta) + + if importance_sampling_level == "sequence" and loss_type in ("cispo", "sapo"): + raise ValueError( + f"Sequence-level importance sampling is not supported for loss_type='{loss_type}'. " + f"Use importance_sampling_level='token' instead." + ) + + if loss_type == "sapo": + if sapo_temperature_pos <= 0: + raise ValueError(f"sapo_temperature_pos must be positive, got {sapo_temperature_pos}") + if sapo_temperature_neg <= 0: + raise ValueError(f"sapo_temperature_neg must be positive, got {sapo_temperature_neg}") + + loss_type_int = _str_to_loss_type[loss_type] + + B, L_ADD_1, N = logits.shape + L = L_ADD_1 - 1 + + if completion_mask is not None: + assert completion_mask.is_contiguous() + + mask = completion_mask.float() if completion_mask is not None else torch.ones(B, L, device=logits.device) + + vllm_is_ratio_ptr = None + vllm_is_ratio_stride = L + if vllm_is_ratio is not None: + assert vllm_is_ratio.dim() in (1, 2), ( + f"vllm_is_ratio must be 1D (B,) or 2D (B, L) / (B, 1), got {vllm_is_ratio.dim()}D" + ) + if vllm_is_ratio.dim() == 2: + assert vllm_is_ratio.shape[0] == B and vllm_is_ratio.shape[1] in (1, L), ( + f"vllm_is_ratio shape must be ({B}, 1) or ({B}, {L}), got {tuple(vllm_is_ratio.shape)}" + ) + else: + assert vllm_is_ratio.shape[0] == B, f"vllm_is_ratio shape must be ({B},), got {tuple(vllm_is_ratio.shape)}" + vllm_is_ratio = vllm_is_ratio.contiguous() + vllm_is_ratio_ptr = vllm_is_ratio + vllm_is_ratio_stride = vllm_is_ratio.shape[1] if vllm_is_ratio.dim() > 1 else 1 + + loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32) + lse = torch.zeros_like(loss) + is_clipped = torch.zeros_like(loss) + kl = torch.zeros_like(loss) if beta != 0.0 else None + + block_n = compute_block_size_forward(N) + num_cores = get_npu_core_count() + grid = calculate_tile_count_2d(B, L, num_cores) + + if importance_sampling_level == "sequence": + per_token_logps = fused_selective_log_softmax(logits, completion_ids, temperature, completion_mask) + + if old_logp is None: + log_ratio = torch.zeros_like(per_token_logps) + else: + log_ratio = per_token_logps - old_logp + + seq_lens = mask.sum(-1).clamp(min=1.0) + seq_log_importance = (log_ratio * mask).sum(-1) / seq_lens + coef_1 = torch.exp(seq_log_importance) + coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high) + + is_clipped_seq = ((coef_1 < 1 - eps_low) & (advantages < 0)) | ((coef_1 > 1 + eps_high) & (advantages > 0)) + is_clipped_seq = is_clipped_seq.float() + + if delta is not None: + coef_1_for_loss = torch.clamp(coef_1, max=delta) + else: + coef_1_for_loss = coef_1 + + _grpo_loss_fwd_kernel_seq[grid]( + logits, + old_logp, + ref_logp, + completion_ids, + completion_mask, + advantages, + coef_1_for_loss.contiguous(), + coef_2.contiguous(), + is_clipped_seq.contiguous(), + vllm_is_ratio_ptr, + vllm_is_ratio_stride, + loss, + lse, + kl, + is_clipped, + temperature, + beta, + use_bias_correction_kl, + L, + N, + BLOCK_N=block_n, + ) + + ctx.save_for_backward( + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + lse, + mask, + coef_1, + seq_lens, + vllm_is_ratio_ptr, + ) + else: + _grpo_loss_fwd_kernel[grid]( + logits, + old_logp, + ref_logp, + completion_ids, + completion_mask, + advantages, + vllm_is_ratio_ptr, + vllm_is_ratio_stride, + 
loss, + lse, + kl, + is_clipped, + temperature, + beta, + eps_low, + eps_high, + loss_type_int, + sapo_temperature_pos, + sapo_temperature_neg, + delta_val, + use_bias_correction_kl, + L, + N, + BLOCK_N=block_n, + ) + ctx.save_for_backward( + logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse, mask, vllm_is_ratio_ptr + ) + + ctx.infos = ( + temperature, + beta, + eps_low, + eps_high, + inplace, + loss_type, + loss_type_int, + sapo_temperature_pos, + sapo_temperature_neg, + max_completion_length, + B, + L, + importance_sampling_level, + vllm_is_ratio_stride, + reduce, + delta_val, + use_bias_correction_kl, + ) + + mask_sum = mask.sum().clamp(min=1.0) + kl_mean = (kl * mask).sum() / mask_sum if kl is not None else None + clip_ratio = (is_clipped.float() * mask).sum() / mask_sum + + if not reduce: + loss_out = loss * mask + kl_out = kl * mask if kl is not None else None + is_clipped_out = is_clipped * mask + return loss_out, kl_out, is_clipped_out + + reduced_loss = reduce_loss(loss, mask, loss_type, max_completion_length, B, L) + return reduced_loss, kl_mean, clip_ratio + + +def grpo_loss_backward_triton(ctx, *args): + """Backward pass computation for GRPO loss.""" + dloss_input = args[0] + saved_tensors = ctx.saved_tensors + ( + temperature, + beta, + eps_low, + eps_high, + inplace, + loss_type, + loss_type_int, + sapo_temperature_pos, + sapo_temperature_neg, + max_completion_length, + B, + L, + importance_sampling_level, + vllm_is_ratio_stride, + reduce, + delta_val, + use_bias_correction_kl, + ) = ctx.infos + + if importance_sampling_level == "sequence": + ( + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + lse, + mask, + coef_1, + seq_lens, + vllm_is_ratio, + ) = saved_tensors + else: + (logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse, mask, vllm_is_ratio) = ( + saved_tensors + ) + + _, L_ADD_1, N = logits.shape + + if not reduce: + dloss = dloss_input + elif loss_type == "grpo" or loss_type == "sapo": + seq_lens_bwd = mask.sum(-1, keepdim=True).clamp(min=1.0) + dloss = dloss_input * mask / (seq_lens_bwd * B) + elif loss_type == "bnpo": + dloss = dloss_input * mask / mask.sum().clamp(min=1.0) + elif loss_type == "dr_grpo": + max_len = max_completion_length if max_completion_length is not None else L + dloss = dloss_input * mask / (B * max_len) + elif loss_type == "dapo" or loss_type == "cispo": + dloss = dloss_input * mask / compute_distribution_normalizer(mask) + elif loss_type == "luspo": + seq_lens_bwd = mask.sum(-1, keepdim=True).clamp(min=1.0) + dloss = dloss_input * seq_lens_bwd / (B * L) + else: + raise ValueError(f"Unknown loss_type: {loss_type}") + + dlogits = logits.data if inplace else torch.empty_like(logits) + + block_n = compute_block_size_backward(N) + num_cores = get_npu_core_count() + grid = calculate_tile_count_2d(B, L, num_cores) + + if importance_sampling_level == "sequence": + if vllm_is_ratio is None: + dloss_sum = dloss.sum(-1).contiguous() + else: + if vllm_is_ratio.dim() == 1: + ratio = vllm_is_ratio.unsqueeze(-1) + else: + ratio = vllm_is_ratio + dloss_sum = (dloss * ratio).sum(-1).contiguous() + _grpo_loss_bwd_kernel_seq[grid]( + dloss, + dloss_sum, + dlogits, + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + lse, + coef_1, + seq_lens, + temperature, + beta, + use_bias_correction_kl, + eps_low, + eps_high, + delta_val, + *dloss.stride(), + L, + N, + BLOCK_N=block_n, + ) + else: + _grpo_loss_bwd_kernel[grid]( + dloss, + dlogits, + 
logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + lse, + vllm_is_ratio, + vllm_is_ratio_stride, + temperature, + beta, + eps_low, + eps_high, + loss_type_int, + sapo_temperature_pos, + sapo_temperature_neg, + delta_val, + use_bias_correction_kl, + *dloss.stride(), + L, + N, + BLOCK_N=block_n, + ) + + dlogits[:, -1, :] = 0 + return ( + dlogits, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +class GrpoLossFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, *args): + return grpo_loss_forward_triton(ctx, *args) + + @staticmethod + @ensure_contiguous + def backward(ctx, *args): + return grpo_loss_backward_triton(ctx, *args) diff --git a/src/liger_kernel/ops/backends/_ascend/ops/jsd.py b/src/liger_kernel/ops/backends/_ascend/ops/jsd.py new file mode 100755 index 0000000000000000000000000000000000000000..a28eecda1e71806b6010927b4453dbfb50e067dc --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/jsd.py @@ -0,0 +1,229 @@ +from typing import Optional + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count + + +@triton.jit +def _jsd_kernel( + X_ptr, # input in logspace, X = log Q + X_stride, + Y_ptr, # ground truth in logspace, Y = log P + Y_stride, + loss_ptr, + loss_stride, + dX_ptr, + dX_stride, + label_ptr, + beta: tl.constexpr, + n_non_ignore: int, + ignore_index: tl.constexpr, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HAS_LABEL: tl.constexpr, +): + # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X) + # = sum(P * log P + Q * log Q - 2 * M * log M) / 2 + # = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2 + # grad_x_i = 0.5 * Q * (X - log_M) + + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + # Grid-Stride Loop - each kernel processes multiple rows + for row_idx in range(pid, n_rows, num_progs): + X_row_ptr = X_ptr + row_idx * X_stride + Y_row_ptr = Y_ptr + row_idx * Y_stride + loss_row_ptr = loss_ptr + row_idx * loss_stride + dX_row_ptr = dX_ptr + row_idx * dX_stride + + should_skip = False + if HAS_LABEL: + label = tl.load(label_ptr + row_idx) + should_skip = label == ignore_index + + if should_skip: + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + tl.store(dX_row_ptr + offsets, 0.0, mask=mask) + tl.store(loss_row_ptr + offsets, 0.0, mask=mask) + else: + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + X = tl.load(X_row_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32) + Y = tl.load(Y_row_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32) + + if beta == 0.0: # forward KL + Y_max = tl.max(Y, axis=0) + Y_shifted = Y - Y_max + Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max) # Compensate for the shift + loss = Y_prob * (Y - X) + dX = -Y_prob + elif beta == 1.0: # reverse KL + X_max = tl.max(X, axis=0) + X_shifted = X - X_max + X_prob = tl.exp(X_shifted) * tl.exp(X_max) # Compensate for the shift + loss = X_prob * (X - Y) + dX = loss + X_prob + else: + max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0)) + X_shifted = X - max_val + Y_shifted = Y - max_val + + # 
Pre-compute exp(max_val) since it's used twice + exp_max = tl.exp(max_val) + + # Compute exp terms with compensation + Q = tl.exp(X_shifted) * exp_max # = exp(X) + P = tl.exp(Y_shifted) * exp_max # = exp(Y) + + # Pre-compute common terms + beta_P = beta * P + one_minus_beta_Q = (1 - beta) * Q + M = beta_P + one_minus_beta_Q + log_M = tl.log(M) + + loss = beta_P * Y + one_minus_beta_Q * X - M * log_M + dX = one_minus_beta_Q * (X - log_M) + + # Pre-compute scaling factor + scale = 1.0 / n_non_ignore + loss = loss * scale + dX = dX * scale + + tl.store(loss_row_ptr + offsets, loss, mask=mask) + tl.store(dX_row_ptr + offsets, dX, mask=mask) + + +def get_optimal_block_size(total_elements): + """ + Calculate optimal Block Size using compute_default_tiling_strategy + """ + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.9, dtype_size=4, memory_multiplier=8.0, shapes=((total_elements,),), tiling_dims=(0,) + ) + + if tile_shapes and len(tile_shapes) > 0: + block_size = tile_shapes[0][0] + return block_size + else: + return 2048 + + +def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label): + BT, V = _input.shape + n_rows = BT + BLOCK_SIZE = get_optimal_block_size(V) + + # non reduction loss + loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device) + dX = torch.empty_like(_input) + + if has_label: + n_non_ignore = (shift_labels != ignore_index).sum().item() + else: + n_non_ignore = BT + + # Use NPU core count for grid size + num_cores = get_npu_core_count() + grid_size = min(num_cores, n_rows) + + _jsd_kernel[(grid_size,)]( + X_ptr=_input, + X_stride=_input.stride(-2), + Y_ptr=target, + Y_stride=target.stride(-2), + loss_ptr=loss, + loss_stride=loss.stride(-2), + dX_ptr=dX, + dX_stride=dX.stride(-2), + label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)), + beta=beta, + n_non_ignore=n_non_ignore, + ignore_index=ignore_index, + n_rows=n_rows, + n_cols=V, + BLOCK_SIZE=BLOCK_SIZE, + HAS_LABEL=has_label, + ) + + loss = torch.sum(loss) + return loss.to(_input.dtype), dX + + +def jsd_backward(dX, grad_output): + # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + return dX + else: + return grad_output * dX + + +class LigerJSDFunction(torch.autograd.Function): + r""" + This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence. + .. math:: + JSD(\beta)(P || Q) + = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q)) + + .. note:: + As all the other losses in PyTorch, this function expects the first argument, + :attr:`_input`, to be the predictions, the output of the student model, in log-space + and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space. + This differs from the standard mathematical notation :math:`JSD(P || Q)` where + :math:`P` denotes the teacher model and :math:`Q` denotes the student model. 
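+
+    The returned loss is summed over all elements and pre-scaled by
+    1 / n_non_ignore inside the kernel, where n_non_ignore counts the rows whose
+    label differs from :attr:`ignore_index` (all rows when no labels are given).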
+ """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, + _input: torch.Tensor, + target: torch.Tensor, + shift_labels: Optional[torch.Tensor] = None, + beta: float = 0.5, + ignore_index: int = -100, + ) -> torch.Tensor: + """ + Args: + _input (torch.Tensor): predict values with shape (BT, V) in logspace + target (torch.Tensor): ground truth values with shape (BT, V) in logspace + shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1]. + beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5` + ignore_index (int): the index to ignore. Default: -100 + + Returns: + loss (torch.Tensor): generalized JSD + """ + has_label = False + if shift_labels is not None: + assert shift_labels.shape == (_input.shape[0],), ( + f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" + ) + shift_labels = shift_labels.contiguous() + has_label = True + + loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label) + ctx.save_for_backward(dX) + return loss + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + (dX,) = ctx.saved_tensors + dX = jsd_backward(dX, grad_output) + return ( + dX, + None, + None, + None, + None, + ) diff --git a/src/liger_kernel/ops/backends/_ascend/ops/kl_div.py b/src/liger_kernel/ops/backends/_ascend/ops/kl_div.py new file mode 100755 index 0000000000000000000000000000000000000000..f7b2614f692f712cf3f0ee320dda77adf832ff86 --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/kl_div.py @@ -0,0 +1,327 @@ +from typing import Literal + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count + +REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"] + +_REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0) +_REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1) +_REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2) +_REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3) + +_str_to_reduction_mode = { + "none": _REDUCTION_MODE_NONE.value, + "sum": _REDUCTION_MODE_SUM.value, + "mean": _REDUCTION_MODE_MEAN.value, + "batchmean": _REDUCTION_MODE_BATCHMEAN.value, +} + +# ----------------------------------------------------------------------------- +# Kernels (2D Tiling + Persistent Programs) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _kldiv_kernel_forward( + y_ptr, # [B, S], prediction ptr, the kernel expects the prediction in log-space + gt_ptr, # [B, S], ground truth ptr + loss_ptr, # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr + n_rows, # int, number of rows in the input tensor + n_cols, # int, number of columns in the input tensor + eps, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + log_target: tl.constexpr = False, + reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN, +): + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + grid_m = tl.cdiv(n_rows, BLOCK_SIZE_M) + grid_n = tl.cdiv(n_cols, BLOCK_SIZE_N) + total_2d_blocks = grid_m * grid_n + + # Persistent-program loop over logical 2D blocks. 
+    for block_idx in tl.range(pid, total_2d_blocks, num_progs):
+        block_m = block_idx // grid_n
+        block_n = block_idx % grid_n
+
+        offset_m = tl.arange(0, BLOCK_SIZE_M) + block_m * BLOCK_SIZE_M
+        offset_n = tl.arange(0, BLOCK_SIZE_N) + block_n * BLOCK_SIZE_N
+
+        mask_m = offset_m < n_rows
+        mask_n = offset_n < n_cols
+
+        offset = offset_m[:, None] * n_cols + offset_n[None, :]
+        mask = mask_m[:, None] & mask_n[None, :]
+
+        y = tl.load(y_ptr + offset, mask=mask, other=0.0)
+        y_true = tl.load(gt_ptr + offset, mask=mask, other=0.0)
+
+        # KL(y_true || y_pred) with y_pred provided in log-space.
+        # - log_target=False: y_true is probability space; clamp with eps before log.
+        # - log_target=True : y_true is log-probability space.
+        if log_target:
+            loss = tl.exp(y_true) * (y_true - y)
+        else:
+            loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
+
+        if reduction == _REDUCTION_MODE_NONE:
+            tl.store(loss_ptr + offset, loss, mask=mask)
+        else:
+            # Multiple block_n tiles may update the same row, so atomic_add is required.
+            loss_sum = tl.sum(loss, axis=1)
+            tl.atomic_add(loss_ptr + offset_m, loss_sum, mask=mask_m)
+
+
+@triton.jit
+def _kldiv_kernel_backward(
+    target_ptr,
+    new_grads_ptr,
+    grad_output_ptr,
+    n_rows,
+    n_cols,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    log_target: tl.constexpr = False,
+    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
+):
+    pid = tl.program_id(0)
+    num_progs = tl.num_programs(0)
+
+    grid_m = tl.cdiv(n_rows, BLOCK_SIZE_M)
+    grid_n = tl.cdiv(n_cols, BLOCK_SIZE_N)
+    total_2d_blocks = grid_m * grid_n
+
+    # For reduced losses, grad_output is a scalar. Load it once per program.
+    if reduction != _REDUCTION_MODE_NONE:
+        grad_output_scalar = tl.load(grad_output_ptr)
+
+    # Persistent-program loop over logical 2D blocks.
+    for block_idx in tl.range(pid, total_2d_blocks, num_progs):
+        block_m = block_idx // grid_n
+        block_n = block_idx % grid_n
+
+        offset_m = tl.arange(0, BLOCK_SIZE_M) + block_m * BLOCK_SIZE_M
+        offset_n = tl.arange(0, BLOCK_SIZE_N) + block_n * BLOCK_SIZE_N
+
+        mask_m = offset_m < n_rows
+        mask_n = offset_n < n_cols
+
+        offset = offset_m[:, None] * n_cols + offset_n[None, :]
+        mask = mask_m[:, None] & mask_n[None, :]
+
+        y_true = tl.load(target_ptr + offset, mask=mask, other=0.0)
+
+        if log_target:
+            res = -tl.exp(y_true)
+        else:
+            res = -y_true
+
+        if reduction != _REDUCTION_MODE_NONE:
+            res = res * grad_output_scalar
+        else:
+            grad_output = tl.load(grad_output_ptr + offset, mask=mask, other=0.0)
+            res = res * grad_output
+
+        if reduction == _REDUCTION_MODE_BATCHMEAN:
+            res = res / n_rows
+        elif reduction == _REDUCTION_MODE_MEAN:
+            res = res / (n_rows * n_cols)
+
+        tl.store(new_grads_ptr + offset, res, mask=mask)
+
+
+# -----------------------------------------------------------------------------
+# Helper: Call compute_default_tiling_strategy
+# -----------------------------------------------------------------------------
+
+
+def get_optimal_block_size(
+    n_rows,
+    dtype_size,
+    BLOCK_SIZE_N: int,
+    log_target: bool = False,
+    is_backward: bool = False,
+    is_scalar_grad_output: bool = True,
+):
+    """
+    Calculate optimal BLOCK_SIZE_M using compute_default_tiling_strategy.
+    """
+    # 1) Set memory multiplier
+    # Backward is lighter than forward in this op, so we typically use a smaller multiplier.
+    # If backward also needs to stream a full grad_output tile (i.e., grad_output is not a scalar),
+    # its memory footprint becomes closer to forward, so we bump the multiplier.
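+    # (Rough accounting, not an exact count: every tile-sized buffer that is
+    # live at once in the kernel, i.e. inputs, a streamed grad_output, and
+    # intermediates, adds one multiple of BLOCK_SIZE_M * BLOCK_SIZE_N * dtype_size
+    # of UB pressure, which is what memory_multiplier approximates.)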
+ if is_backward: + multiplier = 2.5 if is_scalar_grad_output else 3.0 + else: + multiplier = 3.0 if log_target else 6.0 + + # For bf16/fp16 (dtype_size < 4), compile-time UB overflow was observed on some shapes. + # Clamp to fp32 size for a conservative tiling estimate; this can be refined later. + dtype_size = max(dtype_size, 4) + + # 2) Call tiling strategy (tile only dim 0 / rows) + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.9, + dtype_size=dtype_size, + memory_multiplier=multiplier, + shapes=((n_rows, BLOCK_SIZE_N),), + tiling_dims=(0,), + ) + + # 3) Parse result + if tile_shapes and len(tile_shapes) > 0: + block_size = tile_shapes[0][0] + return block_size + else: + return triton.next_power_of_2(min(128, n_rows)) + + +def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V] + BT, V = y_pred.shape + reduction = _str_to_reduction_mode[reduction] + + out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,) + output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32) + + BLOCK_SIZE_N = triton.next_power_of_2(min(128, V)) + BLOCK_SIZE_M = get_optimal_block_size(BT, y_pred.element_size(), BLOCK_SIZE_N, log_target=log_target) + num_cores = get_npu_core_count() + total_blocks = triton.cdiv(BT, BLOCK_SIZE_M) * triton.cdiv(V, BLOCK_SIZE_N) + grid = min(num_cores, total_blocks) + + _kldiv_kernel_forward[(grid,)]( + y_pred, + y_true, + output_tensor, + BT, + V, + eps=eps, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + log_target=log_target, + reduction=reduction, + ) + + # Final reduction follows PyTorch KLDivLoss semantics. + # Note: In newer PyTorch versions, `mean` is planned to match `batchmean`. + # See: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html + if reduction == _REDUCTION_MODE_BATCHMEAN.value: + return output_tensor.sum() / BT + elif reduction == _REDUCTION_MODE_SUM.value: + return output_tensor.sum(dim=0) + elif reduction == _REDUCTION_MODE_MEAN.value: + return output_tensor.sum() / (BT * V) + else: + return output_tensor + + +def kldiv_backward_triton(target, grad_output, new_grads, log_target, reduction): + BT, V = target.shape + reduction = _str_to_reduction_mode[reduction] + + BLOCK_SIZE_N = triton.next_power_of_2(min(128, V)) + # grad_output handling: + # - numel() == 1: use scalar grad_output path in kernel. + # - numel() != 1: stream per-element grad_output tile in kernel. + is_scalar_grad_output = grad_output.numel() == 1 + BLOCK_SIZE_M = get_optimal_block_size( + BT, + target.element_size(), + BLOCK_SIZE_N, + log_target=log_target, + is_backward=True, + is_scalar_grad_output=is_scalar_grad_output, + ) + num_cores = get_npu_core_count() + total_blocks = triton.cdiv(BT, BLOCK_SIZE_M) * triton.cdiv(V, BLOCK_SIZE_N) + grid = min(num_cores, total_blocks) + + _kldiv_kernel_backward[(grid,)]( + target, + new_grads, + grad_output, + BT, + V, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + log_target=log_target, + reduction=reduction, + ) + + return new_grads + + +class LigerKLDivLossFunction(torch.autograd.Function): + """ + Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula: + ```python + if log_target: + loss = target.exp() * (target - input) + else: + loss = target * (target.log() - input) + ```, + then the loss is reduced according to the `reduction` parameter. 
+ as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html + """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, + y_pred: torch.Tensor, + y_true: torch.Tensor, + reduction: REDUCTION_LITERAL = "batchmean", + log_target: bool = False, + eps: float = 1e-10, + ) -> torch.Tensor: + """A forward pass for the KL Divergence Loss. + + Args: + ctx: Torch autograd context + y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities. + y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`. + reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean". + log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False. + eps: (float, optional): A small value to avoid division by zero. Defaults to 1e-10. + + Returns: + torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar. + """ + ctx.save_for_backward(y_true) + ctx.reduction = reduction + ctx.log_target = log_target + return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps) + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + """A backward pass for the KL Divergence Loss. + + Args: + ctx: Torch autograd context + grad_output (torch.Tensor): The gradient of the loss with respect to the output. + + Returns: + tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method. + """ + (y_true,) = ctx.saved_tensors + + new_grads = torch.empty_like(y_true) + + derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target, ctx.reduction) + + return ( + derivative, + None, + None, + None, + None, + ) diff --git a/src/liger_kernel/ops/backends/_ascend/ops/layer_norm.py b/src/liger_kernel/ops/backends/_ascend/ops/layer_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..6e82026d49d02b958627a21d9d2cc93a0df6bf26 --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/layer_norm.py @@ -0,0 +1,642 @@ +import torch +import triton +import triton.language as tl + +from triton.language.math import rsqrt + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count + +# ----------------------------------------------------------------------------- +# Optimized Forward Kernel - No Tiling (for n_cols <= 2048) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _layer_norm_forward_kernel_no_tiling( + Y_ptr, + Y_row_stride, + X_ptr, + X_row_stride, + W_ptr, + B_ptr, + Mean_ptr, + Mean_row_stride, + RSTD_ptr, + RSTD_row_stride, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + eps: tl.constexpr, + n_cols_inv: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + """ + OPTIMIZED NPU layer_norm forward kernel for small n_cols (<= 2048). + + Key optimizations: + 1. Pre-compute n_cols_inv to avoid repeated scalar division + 2. Hoist W and B loads outside the loop (already done) + 3. Minimize per-iteration scalar operations + 4. 
Use vectorized operations for mask handling + 5. Optimize cache hints for memory access patterns + 6. Reduce type conversions by keeping intermediate results in float32 + """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + # Pre-compute grid stride constants (done once, not per iteration) + grid_stride = num_progs * BLOCK_SIZE_M + num_iterations = tl.cdiv(n_rows, grid_stride) + + col_offsets = tl.arange(0, BLOCK_SIZE_N) + col_mask = col_offsets < n_cols + row_offsets = tl.arange(0, BLOCK_SIZE_M) + + # Load W and B once (already optimized - kept outside loop) + W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0).to(tl.float32) + B_row = tl.load(B_ptr + col_offsets, mask=col_mask, other=0.0).to(tl.float32) + + base_row_idx = pid * BLOCK_SIZE_M + + # Grid-stride loop over row blocks + for i in range(num_iterations): + row_idx = i * grid_stride + base_row_idx + row_offsets + row_mask = row_idx < n_rows + + block_mask = row_mask[:, None] & col_mask[None, :] + + X_block_ptr = X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :] + + X_rows = tl.load( + X_block_ptr, + mask=block_mask, + other=0.0, + cache_modifier=".cg", + ).to(tl.float32) + + # Compute mean with vectorized operations + row_sum = tl.sum(X_rows, axis=1) + mean_rows = row_sum * n_cols_inv # Multiplication is faster than division + + # Center the data (vectorized operation) + X_centered = X_rows - mean_rows[:, None] + + X_centered_masked = tl.where(block_mask, X_centered, 0.0) + var_rows = tl.sum(X_centered_masked * X_centered_masked, axis=1) * n_cols_inv + + rstd_rows = rsqrt(var_rows + eps) + + Mean_ptr_offset = Mean_ptr + row_idx * Mean_row_stride + RSTD_ptr_offset = RSTD_ptr + row_idx * RSTD_row_stride + + tl.store(Mean_ptr_offset, mean_rows, mask=row_mask) + tl.store(RSTD_ptr_offset, rstd_rows, mask=row_mask) + + Y_f32 = X_centered * rstd_rows[:, None] * W_row[None, :] + B_row[None, :] + + # Store output with coalesced memory access + Y_block_ptr = Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :] + tl.store(Y_block_ptr, Y_f32, mask=block_mask) + + +# ----------------------------------------------------------------------------- +# Forward Kernel - With Tiling (for n_cols > 2048) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _layer_norm_forward_kernel_npu( + Y_ptr, + Y_row_stride, + X_ptr, + X_row_stride, + W_ptr, + B_ptr, + Mean_ptr, + Mean_row_stride, + RSTD_ptr, + RSTD_row_stride, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + eps, + BLOCK_SIZE: tl.constexpr, +): + """NPU-optimized layer_norm forward kernel with column blocking.""" + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE) + + offsets = tl.arange(0, BLOCK_SIZE) + n_cols_inv = 1.0 / n_cols + + for row_idx in range(pid, n_rows, num_progs): + Y_row_ptr = Y_ptr + row_idx * Y_row_stride + X_row_ptr = X_ptr + row_idx * X_row_stride + Mean_row_ptr = Mean_ptr + row_idx * Mean_row_stride + RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride + + row_sum = 0.0 + + for col_block_idx in range(num_col_blocks): + col_start = col_block_idx * BLOCK_SIZE + col_offsets = col_start + offsets + mask = col_offsets < n_cols + + X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) + + row_sum += tl.sum(X_block) + + mean = row_sum * n_cols_inv + + var_sum = 0.0 + + for col_block_idx in range(num_col_blocks): + col_start = col_block_idx * BLOCK_SIZE + col_offsets = col_start + offsets + 
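+            # Second pass: re-load each column tile and accumulate the centered
+            # squares; tail positions beyond n_cols are masked out of the sum.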
mask = col_offsets < n_cols + + X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) + + X_centered = X_block - mean + var_sum += tl.sum(tl.where(mask, X_centered * X_centered, 0.0)) + + var = var_sum * n_cols_inv + rstd = rsqrt(var + eps) + + tl.store(Mean_row_ptr, mean) + tl.store(RSTD_row_ptr, rstd) + + for col_block_idx in range(num_col_blocks): + col_start = col_block_idx * BLOCK_SIZE + col_offsets = col_start + offsets + mask = col_offsets < n_cols + + X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".ca").to(tl.float32) + W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + B_block = tl.load(B_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + + X_centered = X_block - mean + Y_f32 = X_centered * rstd * W_block + B_block + + tl.store(Y_row_ptr + col_offsets, Y_f32.to(X_block.dtype), mask=mask) + + +# ----------------------------------------------------------------------------- +# Optimized Backward Kernel - No Tiling (for n_cols <= 2048) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _layer_norm_backward_kernel_no_tiling( + X_ptr, + X_row_stride, + W_ptr, + Mean_ptr, + Mean_row_stride, + RSTD_ptr, + RSTD_row_stride, + DX_ptr, + DX_row_stride, + DW_scratch_ptr, + DW_scratch_stride, + DB_scratch_ptr, + DB_scratch_stride, + DY_ptr, + DY_row_stride, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + n_cols_inv: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + """ + OPTIMIZED NPU layer_norm backward kernel for small n_cols (<= 2048). + + Key optimizations: + 1. Pre-compute n_cols_inv to avoid repeated division + 2. Minimize scalar operations in the hot path + 3. Reduce redundant mask computations + 4. Optimize memory access patterns with better cache hints + 5. Keep intermediate results in float32 to reduce conversions + 6. 
Use vectorized operations throughout + """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + grid_stride = num_progs * BLOCK_SIZE_M + num_iterations = tl.cdiv(n_rows, grid_stride) + + col_offsets = tl.arange(0, BLOCK_SIZE_N) + col_mask = col_offsets < n_cols + row_offsets = tl.arange(0, BLOCK_SIZE_M) + + W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0).to(tl.float32) + + # Per-program accumulators for dW/dB + dW_acc = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32) + dB_acc = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32) + + base_row_idx = pid * BLOCK_SIZE_M + + # Grid-stride loop over row blocks + for i in range(num_iterations): + row_idx = i * grid_stride + base_row_idx + row_offsets + row_mask = row_idx < n_rows + + # Pre-compute block mask once + block_mask = row_mask[:, None] & col_mask[None, :] + + X_block_ptr = X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :] + DY_block_ptr = DY_ptr + row_idx[:, None] * DY_row_stride + col_offsets[None, :] + Mean_row_ptr = Mean_ptr + row_idx * Mean_row_stride + RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride + + # Load all required data with appropriate cache hints + # .cg = cache global (read once, don't pollute cache) + X_rows = tl.load(X_block_ptr, mask=block_mask, other=0.0, cache_modifier=".cg").to(tl.float32) + DY_rows = tl.load(DY_block_ptr, mask=block_mask, other=0.0, cache_modifier=".cg").to(tl.float32) + mean_rows = tl.load(Mean_row_ptr, mask=row_mask, other=0.0).to(tl.float32) + rstd_rows = tl.load(RSTD_row_ptr, mask=row_mask, other=0.0).to(tl.float32) + + x_hat = (X_rows - mean_rows[:, None]) * rstd_rows[:, None] + wdy = W_row[None, :] * DY_rows + + x_hat_wdy_masked = tl.where(block_mask, x_hat * wdy, 0.0) + wdy_masked = tl.where(block_mask, wdy, 0.0) + + c1 = tl.sum(x_hat_wdy_masked, axis=1) * n_cols_inv + c2 = tl.sum(wdy_masked, axis=1) * n_cols_inv + + DX_f32 = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd_rows[:, None] + + # Store dX with coalesced memory access + DX_block_ptr = DX_ptr + row_idx[:, None] * DX_row_stride + col_offsets[None, :] + tl.store(DX_block_ptr, DX_f32.to(X_ptr.dtype.element_ty), mask=block_mask) + + dW_acc += tl.sum(tl.where(block_mask, DY_rows * x_hat, 0.0), axis=0) + dB_acc += tl.sum(tl.where(block_mask, DY_rows, 0.0), axis=0) + + # Write accumulated gradients to scratch buffers + DW_scratch_offset = DW_scratch_ptr + pid * DW_scratch_stride + col_offsets + DB_scratch_offset = DB_scratch_ptr + pid * DB_scratch_stride + col_offsets + + tl.store(DW_scratch_offset, dW_acc, mask=col_mask) + tl.store(DB_scratch_offset, dB_acc, mask=col_mask) + + +# ----------------------------------------------------------------------------- +# Backward Kernel - With Tiling (for n_cols > 2048) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _layer_norm_backward_kernel_npu( + X_ptr, + X_row_stride, + W_ptr, + Mean_ptr, + Mean_row_stride, + RSTD_ptr, + RSTD_row_stride, + DX_ptr, + DX_row_stride, + DW_ptr, + DB_ptr, + DY_ptr, + DY_row_stride, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """NPU-optimized layer_norm backward kernel with column blocking.""" + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE) + + offsets = tl.arange(0, BLOCK_SIZE) + n_cols_inv = 1.0 / n_cols + + for row_idx in range(pid, n_rows, num_progs): + X_row_ptr = X_ptr + row_idx * X_row_stride + DY_row_ptr = DY_ptr + row_idx * DY_row_stride + DX_row_ptr = DX_ptr + row_idx * 
DX_row_stride + Mean_row_ptr = Mean_ptr + row_idx * Mean_row_stride + RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride + + mean = tl.load(Mean_row_ptr).to(tl.float32) + rstd = tl.load(RSTD_row_ptr).to(tl.float32) + + sum_x_hat_wdy = 0.0 + sum_wdy = 0.0 + + for col_block_idx in range(num_col_blocks): + col_start = col_block_idx * BLOCK_SIZE + col_offsets = col_start + offsets + mask = col_offsets < n_cols + + X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + DY_block = tl.load(DY_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + + x_hat = (X_block - mean) * rstd + wdy = W_block * DY_block + + sum_x_hat_wdy += tl.sum(tl.where(mask, x_hat * wdy, 0.0)) + sum_wdy += tl.sum(tl.where(mask, wdy, 0.0)) + + c1 = sum_x_hat_wdy * n_cols_inv + c2 = sum_wdy * n_cols_inv + + for col_block_idx in range(num_col_blocks): + col_start = col_block_idx * BLOCK_SIZE + col_offsets = col_start + offsets + mask = col_offsets < n_cols + + X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + DY_block = tl.load(DY_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + + x_hat = (X_block - mean) * rstd + wdy = W_block * DY_block + + DX_block = (wdy - (x_hat * c1 + c2)) * rstd + tl.store(DX_row_ptr + col_offsets, DX_block.to(X_ptr.dtype.element_ty), mask=mask) + + dW_block = DY_block * x_hat + dB_block = DY_block + + tl.atomic_add(DW_ptr + col_offsets, dW_block, mask=mask) + tl.atomic_add(DB_ptr + col_offsets, dB_block, mask=mask) + + +# ----------------------------------------------------------------------------- +# Helper Functions +# ----------------------------------------------------------------------------- + + +def get_optimal_block_size(n_cols, is_forward: bool): + """ + Calculate optimal block size using compute_default_tiling_strategy. + + Memory analysis for forward pass (per row): + - Load: X_block, W_block, B_block (3 blocks) + - Store: Y_block, Mean, RSTD (3 blocks) + - Compute: X_centered, Y intermediate (2 blocks) + - Total: conservative estimate 10 blocks of memory + + Memory analysis for backward pass (per row): + - Load: X_block, DY_block, W_block, Mean, RSTD, existing_DW, existing_DB (7 blocks) + - Store: DX_block, new_DW, new_DB (3 blocks) + - Compute: x_hat, wdy, DX intermediate, dW_block, dB_block (5 blocks) + - Total: conservative estimate 15 blocks of memory + + Args: + n_cols: Number of columns in the tensor + is_forward: Whether this is for forward pass (True) or backward pass (False) + + Returns: + Optimal block size + """ + if n_cols <= 2048: + return triton.next_power_of_2(n_cols) + + memory_multiplier = 10.0 if is_forward else 15.0 + + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.9, + dtype_size=4, + memory_multiplier=memory_multiplier, + shapes=((n_cols,),), + tiling_dims=(0,), + ) + + if tile_shapes and len(tile_shapes) > 0: + block_size = tile_shapes[0][0] + return max(2048, block_size) + else: + return 2048 + + +def _compute_grid_size(n_rows: int, block_size_m: int, num_cores: int) -> int: + """ + Compute the effective grid size for no-tiling kernels. 
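+
+    Worked example (illustrative numbers only): with n_rows=4096, block_size_m=2
+    and num_cores=20 there are cdiv(4096, 2) = 2048 row blocks, so the grid is
+    capped at min(20 * 2, 2048) = 40 programs, each of which grid-strides over
+    roughly 2048 / 40 ≈ 51 row blocks.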
+ + OPTIMIZATION: Balances parallelism with overhead + - Ensures enough work per program to amortize launch costs + - Avoids launching idle programs + - Caps at 2x core count for hardware concurrency + """ + num_row_blocks = triton.cdiv(n_rows, block_size_m) + + return min(num_cores * 2, num_row_blocks) + + +# ----------------------------------------------------------------------------- +# Forward and Backward Functions +# ----------------------------------------------------------------------------- + + +def layer_norm_forward(X, W, B, eps): + """ + NPU-optimized forward pass for LayerNorm. + + Args: + X: Input tensor of shape (..., hidden_size) + W: Weight tensor of shape (hidden_size,) + B: Bias tensor of shape (hidden_size,) + eps: Small constant for numerical stability + + Returns: + Tuple of (output, input, mean, rstd) + """ + shape = X.shape + dim = shape[-1] + X = X.view(-1, dim) + n_rows, n_cols = X.shape + + if X.shape[1] != W.shape[0]: + raise ValueError( + f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) " + f"must match weight size (W.shape[0]={W.shape[0]})" + ) + + # Get optimal block sizes + BLOCK_SIZE = get_optimal_block_size(n_cols, True) + BLOCK_SIZE_M = 2048 // BLOCK_SIZE + + # Allocate output tensors + Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device) + RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device) + + num_cores = get_npu_core_count() + + # Choose kernel + if n_cols <= 2048: + grid_size = _compute_grid_size(n_rows, BLOCK_SIZE_M, num_cores) + n_cols_inv = 1.0 / float(n_cols) + + _layer_norm_forward_kernel_no_tiling[(grid_size,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + B, + Mean, + Mean.stride(0), + RSTD, + RSTD.stride(0), + n_rows, + n_cols, + eps, + n_cols_inv, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE, + ) + else: + grid_size = min(num_cores, n_rows) + _layer_norm_forward_kernel_npu[(grid_size,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + B, + Mean, + Mean.stride(0), + RSTD, + RSTD.stride(0), + n_rows, + n_cols, + eps, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return Y.view(*shape), X, Mean, RSTD + + +def layer_norm_backward(dY, X, W, B, Mean, RSTD): + """ + NPU-optimized backward pass for LayerNorm. 
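+
+    Both kernels implement the standard LayerNorm gradient. With
+    x_hat = (x - mean) * rstd and wdy = w * dy, each row uses
+
+        c1 = mean(x_hat * wdy)
+        c2 = mean(wdy)
+        dx = (wdy - (x_hat * c1 + c2)) * rstd
+
+    while dW = sum(dy * x_hat) and dB = sum(dy) are reduced over rows.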
+ + Args: + dY: Gradient of output + X: Input tensor + W: Weight tensor + B: Bias tensor + Mean: Pre-computed mean + RSTD: Pre-computed reciprocal standard deviation + + Returns: + Tuple of (input_grad, weight_grad, bias_grad) + """ + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + n_rows, n_cols = dY.shape + + # Get optimal block sizes + BLOCK_SIZE = get_optimal_block_size(n_cols, False) + BLOCK_SIZE_M = 2048 // BLOCK_SIZE + + num_cores = get_npu_core_count() + + # Allocate gradient tensors + DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + + # Choose kernel + if n_cols <= 2048: + grid_size = _compute_grid_size(n_rows, BLOCK_SIZE_M, num_cores) + DW_scratch = torch.empty((grid_size, n_cols), dtype=torch.float32, device=W.device) + DB_scratch = torch.empty((grid_size, n_cols), dtype=torch.float32, device=W.device) + + n_cols_inv = 1.0 / float(n_cols) + + _layer_norm_backward_kernel_no_tiling[(grid_size,)]( + X, + X.stride(0), + W, + Mean, + Mean.stride(0), + RSTD, + RSTD.stride(0), + DX, + DX.stride(0), + DW_scratch, + DW_scratch.stride(0), + DB_scratch, + DB_scratch.stride(0), + dY, + dY.stride(0), + n_rows, + n_cols, + n_cols_inv, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE, + ) + + DW = DW_scratch.sum(dim=0) + DB = DB_scratch.sum(dim=0) + else: + grid_size = min(num_cores, n_rows) + + DW = torch.zeros(n_cols, dtype=torch.float32, device=W.device) + DB = torch.zeros(n_cols, dtype=torch.float32, device=W.device) + + _layer_norm_backward_kernel_npu[(grid_size,)]( + X, + X.stride(0), + W, + Mean, + Mean.stride(0), + RSTD, + RSTD.stride(0), + DX, + DX.stride(0), + DW, + DB, + dY, + dY.stride(0), + n_rows, + n_cols, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return DX.view(*shape), DW.to(W.dtype), DB.to(B.dtype) + + +# ----------------------------------------------------------------------------- +# Autograd Function +# ----------------------------------------------------------------------------- + + +class LigerLayerNormFunction(torch.autograd.Function): + """ + OPTIMIZED NPU LayerNorm operation. + + Key optimizations for no-tiling kernels: + 1. Pre-compute 1/n_cols to avoid scalar division (40.6% → <30% target) + 2. Minimize per-iteration scalar operations in grid-stride loops + 3. Hoist constant computations outside loops + 4. Use vectorized operations throughout + 5. Optimize memory access patterns with better cache hints + 6. Reduce type conversions by keeping intermediates in float32 + 7. 
Improve grid sizing for better work distribution + """ + + @staticmethod + @ensure_contiguous + def forward(ctx, X, W, B, eps): + Y, X, Mean, RSTD = layer_norm_forward(X, W, B, eps) + ctx.save_for_backward(X, W, B, Mean, RSTD) + return Y + + @staticmethod + @ensure_contiguous + def backward(ctx, dY): + X, W, B, Mean, RSTD = ctx.saved_tensors + DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD) + return DX, DW, DB, None diff --git a/src/liger_kernel/ops/backends/_ascend/ops/llama4_rope.py b/src/liger_kernel/ops/backends/_ascend/ops/llama4_rope.py new file mode 100755 index 0000000000000000000000000000000000000000..437f81a6e2a4c3198902cd616abd1f16739cecc2 --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/llama4_rope.py @@ -0,0 +1,306 @@ +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy + + +def _cast_and_contiguous(q, k, freqs_complex): + # Align dtype: fp32 only when q is fp32; otherwise keep q dtype for perf + compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype + + if k.dtype != q.dtype: + k = k.to(q.dtype) + + q = q.to(compute_dtype).contiguous() + k = k.to(compute_dtype).contiguous() + freqs_complex = freqs_complex.contiguous() + return q, k, freqs_complex, compute_dtype + + +@triton.jit +def _triton_llama4_rope_npu( + q_ptr, + k_ptr, + freqs_complex_ptr, + q_row_stride, + k_row_stride, + q_head_stride, + k_head_stride, + freqs_row_stride, + sl, + bs: tl.constexpr, + n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + BLOCK_Q: tl.constexpr, + BLOCK_K: tl.constexpr, + imag_sign: tl.constexpr, +): + """ + Llama4 RoPE on Ascend NPU for interleaved complex layout: + - q/k shape: (bs, sl, n_heads, hd) + - freqs_complex_ptr: (sl, hd//2, 2) + """ + pid = tl.program_id(0).to(tl.int64) + batch_idx = pid // sl + seq_idx = pid % sl + + if batch_idx >= bs: + return + + q_base = q_ptr + pid * q_row_stride + k_base = k_ptr + pid * k_row_stride + + freq_base = seq_idx * freqs_row_stride + hd_idx = tl.arange(0, hd) + hd_mask = hd_idx < (hd) + + freq_idx = tl.arange(0, hd) + freq_mask = freq_idx < (hd) + + freqs_complex = tl.load(freqs_complex_ptr + freq_base + freq_idx, mask=freq_mask, other=0.0) + + freqs_complex = freqs_complex.reshape(hd // 2, 2, can_reorder=True) + freqs_real, freqs_imag = tl.split(freqs_complex) + freqs_imag = freqs_imag * imag_sign + + # Q heads (chunked for UB) + for qh_block in range(0, n_qh, BLOCK_Q): + qh_idx = tl.arange(0, BLOCK_Q) + qh_block + qh_mask = qh_idx < n_qh + block_mask = qh_mask[:, None] & hd_mask[None, :] + + head_ptr = q_base + qh_idx[:, None] * q_head_stride + + q_pair = tl.load( + head_ptr + hd_idx[None, :], + mask=block_mask, + other=0.0, + ) + q_pair = q_pair.reshape(BLOCK_Q, hd // 2, 2, can_reorder=True) + q_real, q_imag = tl.split(q_pair) + + new_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag)) + new_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real) + + pair_idx = tl.arange(0, hd // 2) + real_idx = pair_idx * 2 + imag_idx = pair_idx * 2 + 1 + + pair_mask = pair_idx < (hd // 2) + + real_mask = qh_mask[:, None] & pair_mask[None, :] + imag_mask = qh_mask[:, None] & pair_mask[None, :] + + # store real + tl.store( + head_ptr + real_idx[None, :], + new_real, + mask=real_mask, + ) + + # store imag + tl.store( + head_ptr + imag_idx[None, :], + new_imag, + mask=imag_mask, + ) + + # K heads (chunked for UB) + for kh_block in range(0, n_kh, BLOCK_K): + kh_idx = tl.arange(0, BLOCK_K) + kh_block + 
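# Same scheme as the Q loop above: load the interleaved (real, imag) pairs,
+        # rotate them as complex products (k_r + i*k_i) * (f_r + i*f_i), with the
+        # imaginary part of the frequencies pre-scaled by imag_sign (+1.0 for the
+        # forward rotation, -1.0 for the backward/conjugate rotation), then store
+        # the two halves back in place at even/odd offsets. +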
kh_mask = kh_idx < n_kh + block_mask = kh_mask[:, None] & hd_mask[None, :] + + head_ptr = k_base + kh_idx[:, None] * k_head_stride + + k_pair = tl.load( + head_ptr + hd_idx[None, :], + mask=block_mask, + other=0.0, + ) + + k_pair = k_pair.reshape(BLOCK_K, hd // 2, 2, can_reorder=True) + k_real, k_imag = tl.split(k_pair) + + new_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag)) + new_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real) + + pair_idx = tl.arange(0, hd // 2) + real_idx = pair_idx * 2 + imag_idx = pair_idx * 2 + 1 + + pair_mask = pair_idx < (hd // 2) + + real_mask = kh_mask[:, None] & pair_mask[None, :] + imag_mask = kh_mask[:, None] & pair_mask[None, :] + + # store real + tl.store( + head_ptr + real_idx[None, :], + new_real, + mask=real_mask, + ) + + # store imag + tl.store( + head_ptr + imag_idx[None, :], + new_imag, + mask=imag_mask, + ) + + +def llama4_rope_forward(q, k, freqs_cis): + """ + Ascend NPU implementation of Llama4 RoPE. + + q/k: (bs, sl, n_heads, hd) with interleaved complex last-dim layout. + freqs_cis: complex (..., hd//2) OR packed (..., 2*(hd//2)). + """ + original_dtype = q.dtype + + bs, sl, n_qh, hd = q.shape + _, _, n_kh, _ = k.shape + if hd % 2 != 0: + raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}") + + if freqs_cis.is_complex(): + freqs_cis = freqs_cis.reshape(-1, freqs_cis.shape[-1]) + if freqs_cis.shape[0] > sl: + freqs_cis = freqs_cis[:sl] + freqs_cis = torch.view_as_real(freqs_cis) + + q, k, freqs_cis, compute_dtype = _cast_and_contiguous(q, k, freqs_cis) + + # UB tiling strategy: tile heads dimension only + dtype_size = q.element_size() + shapes = ((n_qh, hd), (n_kh, hd)) + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.90, + dtype_size=dtype_size, + memory_multiplier=20.0, + shapes=shapes, + tiling_dims=(0, 0), + ) + + if tile_shapes is not None and len(tile_shapes) == len(shapes): + q_tile_shape, k_tile_shape = tile_shapes + BLOCK_Q, _ = q_tile_shape + BLOCK_K, _ = k_tile_shape + BLOCK_Q = max(BLOCK_Q, 2) + BLOCK_K = max(BLOCK_K, 2) + else: + BLOCK_Q = triton.next_power_of_2(n_qh) + BLOCK_K = triton.next_power_of_2(n_kh) + + n_row = bs * sl + + _triton_llama4_rope_npu[(n_row,)]( + q, + k, + freqs_cis, + q.stride(1), + k.stride(1), + q.stride(2), + k.stride(2), + freqs_cis.stride(0), + sl, + bs, + n_qh, + n_kh, + hd, + BLOCK_Q, + BLOCK_K, + imag_sign=1.0, + ) + + if compute_dtype != original_dtype: + q = q.to(original_dtype) + k = k.to(original_dtype) + return q, k + + +def llama4_rope_backward(dq, dk, freqs_cis): + """ + Ascend NPU implementation of Llama4 RoPE. + + q/k: (bs, sl, n_heads, hd) with interleaved complex last-dim layout. + freqs_cis: complex (..., hd//2) OR packed (..., 2*(hd//2)). 
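+
+    The backward pass reuses the forward kernel with imag_sign=-1.0: RoPE is a
+    pure rotation, so the gradient is the incoming dq/dk rotated through the
+    conjugate (inverse) of freqs_cis.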
+ """ + original_dtype = dq.dtype + + bs, sl, n_qh, hd = dq.shape + _, _, n_kh, _ = dk.shape + if hd % 2 != 0: + raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}") + + if freqs_cis.is_complex(): + freqs_cis = freqs_cis.reshape(-1, freqs_cis.shape[-1]) + if freqs_cis.shape[0] > sl: + freqs_cis = freqs_cis[:sl] + freqs_cis = torch.view_as_real(freqs_cis) + + dq, dk, freqs_cis, compute_dtype = _cast_and_contiguous(dq, dk, freqs_cis) + + # UB tiling strategy: tile heads dimension only + dtype_size = dq.element_size() + shapes = ((n_qh, hd), (n_kh, hd)) + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.90, + dtype_size=dtype_size, + memory_multiplier=20.0, + shapes=shapes, + tiling_dims=(0, 0), + ) + + if tile_shapes is not None and len(tile_shapes) == len(shapes): + q_tile_shape, k_tile_shape = tile_shapes + BLOCK_Q, _ = q_tile_shape + BLOCK_K, _ = k_tile_shape + BLOCK_Q = max(BLOCK_Q, 2) + BLOCK_K = max(BLOCK_K, 2) + else: + BLOCK_Q = triton.next_power_of_2(n_qh) + BLOCK_K = triton.next_power_of_2(n_kh) + + n_row = bs * sl + + _triton_llama4_rope_npu[(n_row,)]( + dq, + dk, + freqs_cis, + dq.stride(1), + dk.stride(1), + dq.stride(2), + dk.stride(2), + freqs_cis.stride(0), + sl, + bs, + n_qh, + n_kh, + hd, + BLOCK_Q, + BLOCK_K, + imag_sign=-1.0, + ) + + if compute_dtype != original_dtype: + dq = dq.to(original_dtype) + dk = dk.to(original_dtype) + return dq, dk + + +class LigerLlama4RopeFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None): + # BLOCK_SIZE is ignored for Ascend (we auto-tile heads by UB), kept for API compatibility + q_out, k_out = llama4_rope_forward(q, k, freqs_cis) + ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis) + return q_out, k_out + + @staticmethod + def backward(ctx, dq, dk): + (freqs_cis,) = ctx.saved_tensors + dq_out, dk_out = llama4_rope_backward(dq, dk, freqs_cis) + return dq_out, dk_out, None, None diff --git a/src/liger_kernel/ops/backends/_ascend/ops/poly_norm.py b/src/liger_kernel/ops/backends/_ascend/ops/poly_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..d4deb2329a177da52deb2214487242774a71feee --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/poly_norm.py @@ -0,0 +1,786 @@ +import torch +import triton +import triton.language as tl + +from triton.language.math import rsqrt + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count + +# ----------------------------------------------------------------------------- +# Forward Kernel - No Tiling (for n_cols <= 2048) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _poly_norm_forward_kernel_no_tiling( + Y_ptr, + Y_row_stride, + X_ptr, + X_row_stride, + W_ptr, # weight: [3] for [w0, w1, w2] + B_ptr, # bias: scalar + RSTD_ptr, # cache rstd for backward: shape (n_rows, 3) + RSTD_row_stride, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + eps, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + """ + NPU-optimized PolyNorm forward kernel for small n_cols (<= 2048). 
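+
+    Reference semantics in plain PyTorch (a minimal illustrative sketch;
+    poly_norm_ref is not part of this module):
+
+        def poly_norm_ref(x, w, b, eps=1e-6):
+            def norm(u):
+                return u * torch.rsqrt(u.pow(2).mean(-1, keepdim=True) + eps)
+            return w[0] * norm(x ** 3) + w[1] * norm(x ** 2) + w[2] * norm(x) + b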
+ + PolyNorm formula: + y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b + where norm(u) = u / sqrt(mean(u²) + ε) + + """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + # Grid-stride loop setup + grid_stride = num_progs * BLOCK_SIZE_M + num_iterations = tl.cdiv(n_rows, grid_stride) + + col_offsets = tl.arange(0, BLOCK_SIZE_N) + col_mask = col_offsets < n_cols + row_offsets = tl.arange(0, BLOCK_SIZE_M) + + # Load weights and bias + w0 = tl.load(W_ptr + 0) + w1 = tl.load(W_ptr + 1) + w2 = tl.load(W_ptr + 2) + b = tl.load(B_ptr) + + # Grid-stride loop over row blocks + for i in range(num_iterations): + row_idx = i * grid_stride + pid * BLOCK_SIZE_M + row_offsets + row_mask = row_idx < n_rows + block_mask = row_mask[:, None] & col_mask[None, :] + + # Load input rows + X_rows = tl.load( + X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :], + mask=block_mask, + other=0.0, + cache_modifier=".cg", + ) + + X_f32 = X_rows.to(tl.float32) + + # Compute x³, x², x + X_pow3 = X_f32 * X_f32 * X_f32 + X_pow2 = X_f32 * X_f32 + X_pow1 = X_f32 + + # Compute norm(x³): norm(u) = u * rsqrt(mean(u²) + eps) + # Mask out out-of-bounds positions to prevent contaminating the sum + mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=1) / n_cols + rstd_3 = rsqrt(mean_square_3 + eps) + norm_x3 = X_pow3 * rstd_3[:, None] + + # Compute norm(x²) + mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=1) / n_cols + rstd_2 = rsqrt(mean_square_2 + eps) + norm_x2 = X_pow2 * rstd_2[:, None] + + # Compute norm(x) + mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=1) / n_cols + rstd_1 = rsqrt(mean_square_1 + eps) + norm_x1 = X_pow1 * rstd_1[:, None] + + # Cache rstd values for backward (store 3 values per row) + tl.store(RSTD_ptr + row_idx * RSTD_row_stride + 0, rstd_3.to(X_rows.dtype), mask=row_mask) + tl.store(RSTD_ptr + row_idx * RSTD_row_stride + 1, rstd_2.to(X_rows.dtype), mask=row_mask) + tl.store(RSTD_ptr + row_idx * RSTD_row_stride + 2, rstd_1.to(X_rows.dtype), mask=row_mask) + + # Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b + Y_f32 = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b + + # Store output + tl.store( + Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :], + Y_f32.to(X_rows.dtype), + mask=block_mask, + ) + + +# ----------------------------------------------------------------------------- +# Forward Kernel - With Tiling (for n_cols > 2048) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _poly_norm_forward_kernel_npu( + Y_ptr, + Y_row_stride, + X_ptr, + X_row_stride, + W_ptr, # weight: [3] for [w0, w1, w2] + B_ptr, # bias: scalar + RSTD_ptr, # cache rstd for backward: shape (n_rows, 3) + RSTD_row_stride, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + eps, + BLOCK_SIZE: tl.constexpr, +): + """ + NPU-optimized PolyNorm forward kernel with column blocking. + + This kernel processes rows using a grid-stride loop pattern: + 1. Each program handles multiple rows + 2. For each row, we process it in column chunks of BLOCK_SIZE + 3. 
Grid size is limited to NPU core count to avoid resource overflow
+
+    Two-pass algorithm per row:
+    - First pass: compute mean_square and rstd for x³, x², x across all column blocks
+    - Second pass: apply normalization and affine transformation
+
+    PolyNorm formula:
+        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+    where norm(u) = u / sqrt(mean(u²) + ε)
+    """
+    pid = tl.program_id(0)
+    num_progs = tl.num_programs(0)
+    num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE)
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+
+    # Load weights and bias
+    w0 = tl.load(W_ptr + 0)
+    w1 = tl.load(W_ptr + 1)
+    w2 = tl.load(W_ptr + 2)
+    b = tl.load(B_ptr)
+
+    # Grid-stride loop over rows
+    for row_idx in range(pid, n_rows, num_progs):
+        Y_row_ptr = Y_ptr + row_idx * Y_row_stride
+        X_row_ptr = X_ptr + row_idx * X_row_stride
+        RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride
+
+        # First pass: compute mean_square for x³, x², x
+        sum_square_3 = 0.0
+        sum_square_2 = 0.0
+        sum_square_1 = 0.0
+
+        for col_block_idx in range(num_col_blocks):
+            col_start = col_block_idx * BLOCK_SIZE
+            col_offsets = col_start + offsets
+            mask = col_offsets < n_cols
+
+            X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32)
+
+            # Compute powers
+            X_pow3 = X_block * X_block * X_block
+            X_pow2 = X_block * X_block
+            X_pow1 = X_block
+
+            sum_square_3 += tl.sum(X_pow3 * X_pow3)
+            sum_square_2 += tl.sum(X_pow2 * X_pow2)
+            sum_square_1 += tl.sum(X_pow1 * X_pow1)
+
+        # Compute rstd values
+        mean_square_3 = sum_square_3 / n_cols
+        mean_square_2 = sum_square_2 / n_cols
+        mean_square_1 = sum_square_1 / n_cols
+
+        rstd_3 = rsqrt(mean_square_3 + eps)
+        rstd_2 = rsqrt(mean_square_2 + eps)
+        rstd_1 = rsqrt(mean_square_1 + eps)
+
+        # Store rstd values
+        tl.store(RSTD_row_ptr + 0, rstd_3)
+        tl.store(RSTD_row_ptr + 1, rstd_2)
+        tl.store(RSTD_row_ptr + 2, rstd_1)
+
+        # Second pass: normalize and apply affine transformation
+        for col_block_idx in range(num_col_blocks):
+            col_start = col_block_idx * BLOCK_SIZE
+            col_offsets = col_start + offsets
+            mask = col_offsets < n_cols
+
+            # Load input
+            X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".ca").to(tl.float32)
+
+            # Compute powers
+            X_pow3 = X_block * X_block * X_block
+            X_pow2 = X_block * X_block
+            X_pow1 = X_block
+
+            # Apply normalization
+            norm_x3 = X_pow3 * rstd_3
+            norm_x2 = X_pow2 * rstd_2
+            norm_x1 = X_pow1 * rstd_1
+
+            # Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+            Y_f32 = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b
+
+            # Store result
+            tl.store(Y_row_ptr + col_offsets, Y_f32.to(X_block.dtype), mask=mask)
+
+
+# -----------------------------------------------------------------------------
+# Backward Kernel - No Tiling (for n_cols <= 2048)
+# -----------------------------------------------------------------------------
+
+
+@triton.jit
+def _poly_norm_backward_kernel_no_tiling(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_scratch_ptr,  # shape: (n_programs, 3)
+    dW_scratch_stride,
+    dB_scratch_ptr,  # shape: (n_programs,)
+    n_rows: tl.constexpr,
+    n_cols: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+):
+    """
+    NPU-optimized PolyNorm backward kernel for small n_cols (<= 2048).
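+
+    Instead of atomics, each program accumulates its dW/dB contribution locally
+    and writes it to a private row of the (n_programs, 3) / (n_programs,)
+    scratch buffers; the host finishes the reduction with dW_scratch.sum(dim=0)
+    and dB_scratch.sum().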
+ + Backward pass equations: + ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)] + + where: + - D_p = RMS(x^p) = 1/rstd_p + - S_p = sum(grad * x^p) over the row + - d = n_cols + - p ∈ {3, 2, 1} + """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + # Grid-stride loop setup + grid_stride = num_progs * BLOCK_SIZE_M + num_iterations = tl.cdiv(n_rows, grid_stride) + + col_offsets = tl.arange(0, BLOCK_SIZE_N) + col_mask = col_offsets < n_cols + row_offsets = tl.arange(0, BLOCK_SIZE_M) + + # Load weights + w0 = tl.load(W_ptr + 0).to(tl.float32) + w1 = tl.load(W_ptr + 1).to(tl.float32) + w2 = tl.load(W_ptr + 2).to(tl.float32) + + # Each program accumulates its own dW/dB contribution to avoid atomic contention + dW0_acc = 0.0 + dW1_acc = 0.0 + dW2_acc = 0.0 + dB_acc = 0.0 + + # Grid-stride loop over row blocks + for i in range(num_iterations): + row_idx = i * grid_stride + pid * BLOCK_SIZE_M + row_offsets + row_mask = row_idx < n_rows + block_mask = row_mask[:, None] & col_mask[None, :] + + # Load input and gradient data + X_rows = tl.load( + X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :], + mask=block_mask, + other=0.0, + cache_modifier=".cg", + ) + dY_rows = tl.load( + dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :], + mask=block_mask, + other=0.0, + cache_modifier=".cg", + ) + + # Load cached rstd values (3 values per row) + rstd_3 = tl.load(RSTD_ptr + row_idx * RSTD_row_stride + 0, mask=row_mask, other=0.0).to(tl.float32) + rstd_2 = tl.load(RSTD_ptr + row_idx * RSTD_row_stride + 1, mask=row_mask, other=0.0).to(tl.float32) + rstd_1 = tl.load(RSTD_ptr + row_idx * RSTD_row_stride + 2, mask=row_mask, other=0.0).to(tl.float32) + + X_f32 = X_rows.to(tl.float32) + dY_f32 = dY_rows.to(tl.float32) + + # Compute powers + X_pow3 = X_f32 * X_f32 * X_f32 + X_pow2 = X_f32 * X_f32 + X_pow1 = X_f32 + + # Accumulate bias gradient: dB = sum(dY) + dB_acc += tl.sum(dY_f32) + + # Compute gradient w.r.t. 
input using closed-form formula + # For p=3: ∂L/∂x from w0 * norm(x³) + S_3 = tl.sum(dY_f32 * X_pow3, axis=1) # sum over columns for each row + grad_x_3 = w0 * ( + 3.0 * X_pow2 * rstd_3[:, None] * dY_f32 + - (3.0 / n_cols) * X_pow2 * X_pow3 * (rstd_3[:, None] * rstd_3[:, None] * rstd_3[:, None]) * S_3[:, None] + ) + + # For p=2: ∂L/∂x from w1 * norm(x²) + S_2 = tl.sum(dY_f32 * X_pow2, axis=1) + grad_x_2 = w1 * ( + 2.0 * X_pow1 * rstd_2[:, None] * dY_f32 + - (2.0 / n_cols) * X_pow1 * X_pow2 * (rstd_2[:, None] * rstd_2[:, None] * rstd_2[:, None]) * S_2[:, None] + ) + + # For p=1: ∂L/∂x from w2 * norm(x) + S_1 = tl.sum(dY_f32 * X_pow1, axis=1) + grad_x_1 = w2 * ( + 1.0 * rstd_1[:, None] * dY_f32 + - (1.0 / n_cols) * X_pow1 * (rstd_1[:, None] * rstd_1[:, None] * rstd_1[:, None]) * S_1[:, None] + ) + + # Total gradient + dX_f32 = grad_x_3 + grad_x_2 + grad_x_1 + + # Store dX + tl.store( + dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :], + dX_f32.to(X_ptr.dtype.element_ty), + mask=block_mask, + ) + + # Accumulate weight gradients using closed-form: dW_p = rstd_p * S_p + dW0_acc += tl.sum(rstd_3 * S_3) + dW1_acc += tl.sum(rstd_2 * S_2) + dW2_acc += tl.sum(rstd_1 * S_1) + + # Write this program's accumulated dW/dB to its dedicated scratch row + tl.store(dW_scratch_ptr + pid * dW_scratch_stride + 0, dW0_acc) + tl.store(dW_scratch_ptr + pid * dW_scratch_stride + 1, dW1_acc) + tl.store(dW_scratch_ptr + pid * dW_scratch_stride + 2, dW2_acc) + tl.store(dB_scratch_ptr + pid, dB_acc) + + +# ----------------------------------------------------------------------------- +# Backward Kernel - With Tiling (for n_cols > 2048) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _poly_norm_backward_kernel_npu( + dY_ptr, + dY_row_stride, + dX_ptr, + dX_row_stride, + X_ptr, + X_row_stride, + W_ptr, + RSTD_ptr, + RSTD_row_stride, + dW_ptr, + dB_ptr, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + NPU-optimized PolyNorm backward kernel with column blocking. + + Each program processes multiple rows using grid-stride loop. + For each row, we process columns in blocks to avoid UB overflow. 
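+
+    Two passes are required because the row-level sums S_p appear in the
+    gradient of every element of the row, so they must be fully accumulated
+    before any element of dX can be computed.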
+
+    Two-pass algorithm:
+    - First pass: compute S_p = sum(grad * x^p) for p ∈ {3, 2, 1}
+    - Second pass: compute gradients dX, dW, dB
+    """
+    pid = tl.program_id(0)
+    num_progs = tl.num_programs(0)
+    num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE)
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+
+    # Load weights
+    w0 = tl.load(W_ptr + 0).to(tl.float32)
+    w1 = tl.load(W_ptr + 1).to(tl.float32)
+    w2 = tl.load(W_ptr + 2).to(tl.float32)
+
+    dw0_acc = 0.0
+    dw1_acc = 0.0
+    dw2_acc = 0.0
+    db_acc = 0.0
+
+    # Grid-stride loop over rows
+    for row_idx in range(pid, n_rows, num_progs):
+        dY_row_ptr = dY_ptr + row_idx * dY_row_stride
+        X_row_ptr = X_ptr + row_idx * X_row_stride
+        dX_row_ptr = dX_ptr + row_idx * dX_row_stride
+        RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride
+
+        # Load cached rstd values
+        rstd_3 = tl.load(RSTD_row_ptr + 0).to(tl.float32)
+        rstd_2 = tl.load(RSTD_row_ptr + 1).to(tl.float32)
+        rstd_1 = tl.load(RSTD_row_ptr + 2).to(tl.float32)
+
+        # First pass: compute S_p = sum(grad * x^p)
+        S_3 = 0.0
+        S_2 = 0.0
+        S_1 = 0.0
+
+        for col_block_idx in range(num_col_blocks):
+            col_start = col_block_idx * BLOCK_SIZE
+            col_offsets = col_start + offsets
+            mask = col_offsets < n_cols
+
+            X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+            dY_block = tl.load(dY_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+
+            # Compute powers
+            X_pow3 = X_block * X_block * X_block
+            X_pow2 = X_block * X_block
+            X_pow1 = X_block
+
+            S_3 += tl.sum(dY_block * X_pow3)
+            S_2 += tl.sum(dY_block * X_pow2)
+            S_1 += tl.sum(dY_block * X_pow1)
+
+        # Second pass: compute gradients
+        for col_block_idx in range(num_col_blocks):
+            col_start = col_block_idx * BLOCK_SIZE
+            col_offsets = col_start + offsets
+            mask = col_offsets < n_cols
+
+            X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+            dY_block = tl.load(dY_row_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+
+            # Compute powers
+            X_pow3 = X_block * X_block * X_block
+            X_pow2 = X_block * X_block
+            X_pow1 = X_block
+
+            # Compute gradient w.r.t. input using closed-form formula
+            # For p=3: ∂L/∂x from w0 * norm(x³)
+            grad_x_3 = w0 * (
+                3.0 * X_pow2 * rstd_3 * dY_block - (3.0 / n_cols) * X_pow2 * X_pow3 * (rstd_3 * rstd_3 * rstd_3) * S_3
+            )
+
+            # For p=2: ∂L/∂x from w1 * norm(x²)
+            grad_x_2 = w1 * (
+                2.0 * X_pow1 * rstd_2 * dY_block - (2.0 / n_cols) * X_pow1 * X_pow2 * (rstd_2 * rstd_2 * rstd_2) * S_2
+            )
+
+            # For p=1: ∂L/∂x from w2 * norm(x)
+            grad_x_1 = w2 * (1.0 * rstd_1 * dY_block - (1.0 / n_cols) * X_pow1 * (rstd_1 * rstd_1 * rstd_1) * S_1)
+
+            # Total gradient
+            dX_block = grad_x_3 + grad_x_2 + grad_x_1
+
+            # Store dX
+            tl.store(dX_row_ptr + col_offsets, dX_block.to(X_ptr.dtype.element_ty), mask=mask)
+
+            dw0_acc += tl.sum(rstd_3 * dY_block * X_pow3)
+            dw1_acc += tl.sum(rstd_2 * dY_block * X_pow2)
+            dw2_acc += tl.sum(rstd_1 * dY_block * X_pow1)
+            db_acc += tl.sum(dY_block)
+
+    # Combine the per-program partial sums with atomic adds. Several programs
+    # run concurrently and all target the same dW/dB locations, so a plain
+    # tl.store here would let the last program overwrite the others' results;
+    # the host allocates dW/dB with torch.zeros precisely so these atomics can
+    # accumulate into them.
+    tl.atomic_add(dW_ptr + 0, dw0_acc)
+    tl.atomic_add(dW_ptr + 1, dw1_acc)
+    tl.atomic_add(dW_ptr + 2, dw2_acc)
+    tl.atomic_add(dB_ptr, db_acc)
+
+
+# -----------------------------------------------------------------------------
+# Helper Functions
+# -----------------------------------------------------------------------------
+
+
+def get_optimal_block_size(n_cols, is_forward: bool):
+    """
+    Calculate optimal block size using compute_default_tiling_strategy.
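+
+    For n_cols <= 2048 the whole row fits in UB, so the block size is simply the
+    next power of two; for larger rows compute_default_tiling_strategy picks a
+    tile that fits the memory budget below, clamped to at least 2048.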
+ + Memory analysis for forward pass (per row): + - Load: X_block (1 block) + - Compute: X_pow3, X_pow2, X_pow1, norm_x3, norm_x2, norm_x1 (6 blocks) + - Total: conservative estimate 8 blocks of memory + + Memory analysis for backward pass (per row): + - Load: X_block, dY_block, RSTD (3 blocks) + - Compute: X_pow3, X_pow2, X_pow1, grad_x_3, grad_x_2, grad_x_1 (6 blocks) + - Total: conservative estimate 10 blocks of memory + + Args: + n_cols: Number of columns in the tensor + is_forward: Whether this is for forward pass (True) or backward pass (False) + + Returns: + Optimal block size + """ + if n_cols <= 2048: + return triton.next_power_of_2(n_cols) + + memory_multiplier = 8.0 if is_forward else 10.0 + + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.8, + dtype_size=4, + memory_multiplier=memory_multiplier, + shapes=((n_cols,),), + tiling_dims=(0,), + ) + + if tile_shapes and len(tile_shapes) > 0: + block_size = tile_shapes[0][0] + return max(2048, block_size) + else: + return 2048 + + +def _compute_grid_size(n_rows: int, block_size_m: int, num_cores: int) -> int: + """ + Compute the effective grid size for no-tiling kernels. + + Limits the grid to the minimum of: + - The number of row blocks actually needed (ceil(n_rows / BLOCK_SIZE_M)), which + prevents launching idle programs that would waste core cycles + - NPU core count, which is the hardware concurrency upper bound + + Args: + n_rows: Total number of rows to process + block_size_m: Number of rows each program handles per iteration + num_cores: Number of available NPU cores + + Returns: + Effective grid size + """ + num_row_blocks = triton.cdiv(n_rows, block_size_m) + return min(num_cores, num_row_blocks) + + +# ----------------------------------------------------------------------------- +# Forward and Backward Functions +# ----------------------------------------------------------------------------- + + +def poly_norm_forward(X, W, B, eps=1e-6): + """ + PolyNorm Forward Pass + + Args: + X: input tensor of shape (*, H) where H is hidden dimension + W: weight tensor of shape (3,) for [w0, w1, w2] + B: bias scalar tensor + eps: epsilon for numerical stability + + Returns: + Y: output tensor of same shape as X + X: reshaped input (for backward) + RSTD: cached rstd values (for backward) + BLOCK_SIZE: block size used + """ + shape = X.shape + dim = shape[-1] + X = X.view(-1, dim) + n_rows, n_cols = X.shape + + # Check constraints + assert W.shape[0] == 3, "Weight tensor must have shape (3,)" + assert B.numel() == 1, "Bias must be a scalar" + + # Get optimal block sizes + BLOCK_SIZE = get_optimal_block_size(n_cols, True) + BLOCK_SIZE_M = 2048 // BLOCK_SIZE + + # RSTD is to cache rstd for each row (3 values per row) + Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device) + + # Grid size + num_cores = get_npu_core_count() + + # Choose kernel based on n_cols + if n_cols <= 2048: + # Small kernel: use 2D tensor loading + grid_size = _compute_grid_size(n_rows, BLOCK_SIZE_M, num_cores) + + _poly_norm_forward_kernel_no_tiling[(grid_size,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + B, + RSTD, + RSTD.stride(0), + n_rows, + n_cols, + eps, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE, + ) + else: + # Large kernel: use column blocking + grid_size = min(num_cores, n_rows) + + _poly_norm_forward_kernel_npu[(grid_size,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + B, + RSTD, + RSTD.stride(0), + n_rows, + n_cols, + eps, + 
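# column tile width chosen by get_optimal_block_size(n_cols, True) above +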
BLOCK_SIZE=BLOCK_SIZE, + ) + + return Y.view(*shape), X, RSTD + + +def poly_norm_backward(dY, X, W, RSTD, in_place): + """ + PolyNorm Backward Pass + + Args: + dY: gradient of output + X: input tensor (already reshaped to 2D) + W: weight tensor + RSTD: cached rstd values from forward + BLOCK_SIZE: block size from forward + in_place: whether to in-place modify dY to store dX (saves memory) + + Returns: + dX: gradient w.r.t. input + dW: gradient w.r.t. weight + dB: gradient w.r.t. bias + """ + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + n_rows, n_cols = dY.shape + + # Get optimal block sizes + BLOCK_SIZE_BACKWARD = get_optimal_block_size(n_cols, False) + BLOCK_SIZE_M = 2048 // BLOCK_SIZE_BACKWARD + + # Grid size + num_cores = get_npu_core_count() + + # Allocate or reuse gradients + if in_place is True: + dX = dY + else: + dX = torch.zeros_like(dY) + + # Choose kernel based on n_cols + if n_cols <= 2048: + # Small kernel: use 2D tensor loading with scratch buffers + grid_size = _compute_grid_size(n_rows, BLOCK_SIZE_M, num_cores) + + # Allocate per-program scratch buffers for dW and dB + dW_scratch = torch.empty((grid_size, 3), dtype=torch.float32, device=W.device) + dB_scratch = torch.empty((grid_size,), dtype=torch.float32, device=W.device) + + _poly_norm_backward_kernel_no_tiling[(grid_size,)]( + dY, + dY.stride(0), + dX, + dX.stride(0), + X, + X.stride(0), + W, + RSTD, + RSTD.stride(0), + dW_scratch, + dW_scratch.stride(0), + dB_scratch, + n_rows, + n_cols, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_BACKWARD, + ) + + dW = dW_scratch.sum(dim=0).to(W.dtype) + dB = dB_scratch.sum().to(W.dtype) + else: + # Large kernel: use column blocking with atomic operations + grid_size = min(num_cores, n_rows) + + dW = torch.zeros(3, dtype=torch.float32, device=W.device) + dB = torch.zeros(1, dtype=torch.float32, device=W.device) + + _poly_norm_backward_kernel_npu[(grid_size,)]( + dY, + dY.stride(0), + dX, + dX.stride(0), + X, + X.stride(0), + W, + RSTD, + RSTD.stride(0), + dW, + dB, + n_rows, + n_cols, + BLOCK_SIZE=BLOCK_SIZE_BACKWARD, + ) + + dW = dW.to(W.dtype) + dB = dB.squeeze().to(W.dtype) + + # Reshape dX back to original shape + dX = dX.view(*shape) + + return dX, dW, dB + + +# ----------------------------------------------------------------------------- +# Autograd Function +# ----------------------------------------------------------------------------- + + +class LigerPolyNormFunction(torch.autograd.Function): + """ + PolyNorm Function with forward and backward pass + + PolyNorm formula: + y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b + where norm(u) = u / sqrt(mean(u²) + ε) + + Backward uses closed-form gradient: + ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)] + """ + + @staticmethod + @ensure_contiguous + def forward(ctx, X, W, B, eps=1e-6, in_place=True): + """ + Args: + X: input tensor of shape (B, T, H) or (BxT, H) + W: weight tensor of shape (3,) for [w0, w1, w2] + B: bias scalar + eps: epsilon for numerical stability + in_place: whether to in-place modify grad_output in backward (saves memory) + + Returns: + Y: output tensor of same shape as X + """ + Y, X, RSTD = poly_norm_forward(X, W, B, eps) + ctx.in_place = in_place + ctx.save_for_backward(X, W, RSTD) + return Y + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output): + """ + Args: + grad_output: gradient of output + + Returns: + dX, dW, dB: gradients w.r.t. 
X, W, B + """ + X, W, RSTD = ctx.saved_tensors + dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.in_place) + return dX, dW, dB, None, None diff --git a/src/liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py b/src/liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py new file mode 100755 index 0000000000000000000000000000000000000000..d273b7ec972e48e210819a9d9b45745524513554 --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py @@ -0,0 +1,272 @@ +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import get_npu_core_count + + +@triton.jit +def _triton_qwen2vl_mrope_npu( + q_ptr, + q_row_stride, + k_ptr, + k_row_stride, + cos, + sin, + sl, + bs: tl.constexpr, + total_rows: tl.constexpr, + n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + mrope_section_t: tl.constexpr, + mrope_section_h: tl.constexpr, + BLOCK_Q: tl.constexpr, + BLOCK_K: tl.constexpr, + BACKWARD_PASS: tl.constexpr = False, +): + program_id = tl.program_id(0) + num_programs = tl.num_programs(0) + + rows_per_program = (total_rows + num_programs - 1) // num_programs + start_row = program_id * rows_per_program + actual_rows = tl.minimum(rows_per_program, total_rows - start_row) + + for row_offset in tl.range(0, actual_rows): + pid = start_row + row_offset + + t_end = mrope_section_t + h_end = t_end + mrope_section_h + + t_cos = cos + pid * hd + h_cos = t_cos + bs * sl * hd + w_cos = h_cos + bs * sl * hd + t_sin = sin + pid * hd + h_sin = t_sin + bs * sl * hd + w_sin = h_sin + bs * sl * hd + + q_base = q_ptr + pid * q_row_stride + k_base = k_ptr + pid * k_row_stride + + d_idx = tl.arange(0, hd // 2) + d_mask = d_idx < (hd // 2) + + pos_mask_t = d_idx < t_end + pos_mask_h = (d_idx >= t_end) & (d_idx < h_end) + + text_cos_vals = tl.load(t_cos + d_idx, mask=d_mask, other=0) + text_sin_vals = tl.load(t_sin + d_idx, mask=d_mask, other=0) + height_cos_vals = tl.load(h_cos + d_idx, mask=d_mask, other=0) + height_sin_vals = tl.load(h_sin + d_idx, mask=d_mask, other=0) + width_cos_vals = tl.load(w_cos + d_idx, mask=d_mask, other=0) + width_sin_vals = tl.load(w_sin + d_idx, mask=d_mask, other=0) + + cos_vals = tl.where(pos_mask_t, text_cos_vals, tl.where(pos_mask_h, height_cos_vals, width_cos_vals)) + sin_vals = tl.where(pos_mask_t, text_sin_vals, tl.where(pos_mask_h, height_sin_vals, width_sin_vals)) + + # Process q heads in chunks to prevent UB overflow + for qh_block in range(0, n_qh, BLOCK_Q): + qh_idx = tl.arange(0, BLOCK_Q) + qh_block + qh_mask = qh_idx < n_qh + + block_mask = qh_mask[:, None] & d_mask[None, :] + offsets = qh_idx[:, None] * hd + d_idx[None, :] + + q_left = tl.load(q_base + offsets, mask=block_mask, other=0) + q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0) + + if not BACKWARD_PASS: + new_left = q_left * cos_vals - q_right * sin_vals + new_right = q_right * cos_vals + q_left * sin_vals + else: + new_left = q_left * cos_vals + q_right * sin_vals + new_right = q_right * cos_vals - q_left * sin_vals + + tl.store(q_base + offsets, new_left, mask=block_mask) + tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask) + + # Process k heads in chunks to prevent UB overflow + for kh_block in range(0, n_kh, BLOCK_K): + kh_idx = tl.arange(0, BLOCK_K) + kh_block + kh_mask = kh_idx < n_kh + + block_mask = kh_mask[:, None] & d_mask[None, :] + offsets = kh_idx[:, None] * hd + d_idx[None, :] + + k_left = tl.load(k_base 
+ offsets, mask=block_mask, other=0) + k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0) + + if not BACKWARD_PASS: + new_left = k_left * cos_vals - k_right * sin_vals + new_right = k_right * cos_vals + k_left * sin_vals + else: + new_left = k_left * cos_vals + k_right * sin_vals + new_right = k_right * cos_vals - k_left * sin_vals + + tl.store(k_base + offsets, new_left, mask=block_mask) + tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask) + + +def get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size): + # MROPE forward tiling strategy: + # - cos_vals and sin_vals (include text, height and width) are loaded once outside loops (shared): (pad_hd // 2) * 6 = 3 * pad_hd elements each + # - In q heads loop (peak memory): + # * q_left: BLOCK_Q * (pad_hd // 2) elements + # * q_right: BLOCK_Q * (pad_hd // 2) elements + # * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result) + # * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result) + # * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements + # - In k heads loop (peak memory): + # * k_left: BLOCK_K * (pad_hd // 2) elements + # * k_right: BLOCK_K * (pad_hd // 2) elements + # * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result) + # * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result) + # * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements + # - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case + # - Plus shared cos/sin: 6 * (pad_hd // 2) = 3 * pad_hd elements + # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + 3 * pad_hd) * dtype_size * 8 bits + # - Simplified: (2 * BLOCK_SIZE + 3) * pad_hd * dtype_size * 8 bits + # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits + # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd)) + # - tiling_dims: (0, 0) means first dimension of each shape can be tiled + # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd)) + shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd)) + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.90, + dtype_size=dtype_size, + memory_multiplier=3.0, + shapes=shapes, + tiling_dims=(0, 0), + ) + + if tile_shapes is not None and len(tile_shapes) == len(shapes): + # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd)) + q_tile_shape, k_tile_shape = tile_shapes + BLOCK_Q, _ = q_tile_shape + BLOCK_K, _ = k_tile_shape + else: + # Fallback to conservative defaults + BLOCK_Q = 2048 + BLOCK_K = 2048 + + return BLOCK_Q, BLOCK_K + + +def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section): + # transpose it back to the physical shape because Triton looks at the physical storage + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = q.shape + n_kv_head = k.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + + n_row = batch_size * seq_len + + # ensure tensors passed into the kernel are contiguous + q = q.contiguous() + k = k.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + + dtype_size = q.element_size() + BLOCK_Q, BLOCK_K = get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size) + + num_cores = get_npu_core_count() + grid_size = min(num_cores, n_row) + + _triton_qwen2vl_mrope_npu[(grid_size,)]( + q, + q.stride(1), + k, + k.stride(1), + cos, + sin, + seq_len, + 
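# same launch as the forward pass; BACKWARD_PASS=True applies the inverse rotation +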
batch_size, + n_row, + n_q_head, + n_kv_head, + head_dim, + mrope_section[0], + mrope_section[1], + BLOCK_Q, + BLOCK_K, + BACKWARD_PASS=False, + ) + return q.transpose(1, 2), k.transpose(1, 2), cos, sin + + +def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section): + dq = dq.transpose(1, 2) + dk = dk.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = dq.shape + n_kv_head = dk.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + + n_row = batch_size * seq_len + + # ensure dq and dk are contiguous + dq = dq.contiguous() + dk = dk.contiguous() + + dtype_size = dq.element_size() + BLOCK_Q, BLOCK_K = get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size) + + num_cores = get_npu_core_count() + grid_size = min(num_cores, n_row) + + _triton_qwen2vl_mrope_npu[(grid_size,)]( + dq, + dq.stride(1), + dk, + dk.stride(1), + cos, + sin, + seq_len, + batch_size, + n_row, + n_q_head, + n_kv_head, + head_dim, + mrope_section[0], + mrope_section[1], + BLOCK_Q, + BLOCK_K, + BACKWARD_PASS=True, + ) + return dq.transpose(1, 2), dk.transpose(1, 2) + + +class LigerQwen2VLMRopeFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """ + q size: (bsz, n_q_head, seq_len, head_dim) + k size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (3, bsz, seq_len, head_dim) + sin size: (3, bsz, seq_len, head_dim) + """ + q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section) + ctx.save_for_backward(cos, sin) + ctx.mrope_section = mrope_section + return q, k + + @staticmethod + def backward(ctx, dq, dk): + """ + dq size: (bsz, n_q_head, seq_len, head_dim) + dk size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (3, bsz, seq_len, head_dim) + sin size: (3, bsz, seq_len, head_dim) + """ + cos, sin = ctx.saved_tensors + mrope_section = ctx.mrope_section + dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section) + return dq, dk, None, None, None, None diff --git a/src/liger_kernel/ops/backends/_ascend/ops/rms_norm.py b/src/liger_kernel/ops/backends/_ascend/ops/rms_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..f52fe40e16d58ffd716d307b1f037274a5c8408c --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/rms_norm.py @@ -0,0 +1,782 @@ +import torch +import triton +import triton.language as tl + +from triton.language.math import rsqrt + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count +from liger_kernel.ops.utils import torch_to_triton_dtype + +_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1) +_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0) +_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1) + + +def torch_dtype_to_triton(dtype): + mapping = { + torch.float32: tl.float32, + torch.bfloat16: tl.bfloat16, + } + return mapping.get(dtype, tl.float32) + + +# ----------------------------------------------------------------------------- +# Forward Kernel - No Tiling (for n_cols <= 2048) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _rms_norm_forward_kernel_no_tiling( + Y_ptr, + Y_row_stride, + X_ptr, + X_row_stride, + W_ptr, + RSTD_ptr, + RSTD_row_stride, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + eps, + offset, + casting_mode: tl.constexpr, + 
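# llama: normalize in fp32 and cast back before the weight multiply;
+    # gemma: stay in fp32 through the affine step; none: keep the input dtype +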
elementwise_affine: tl.constexpr, + X_DTYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + """ + NPU-optimized rms_norm forward kernel for small n_cols (< 2048). + + Performance optimizations: + 1. Use 2D vector loading to maximize UB utilization (e.g., (1,2048), (2,1024), (4,512)) + 2. Process multiple rows at once using 2D indexing + 3. Keep data in registers, minimize conversions + 4. Use optimal cache policies + + Used when n_cols < 2048 to avoid the overhead of column blocking. + """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + if casting_mode == _CASTING_MODE_NONE: + eps = eps.to(X_DTYPE) + offset = offset.to(X_DTYPE) + + # Grid-stride loop setup for 2D blocks + grid_stride = num_progs * BLOCK_SIZE_M + num_iterations = tl.cdiv(n_rows, grid_stride) + + col_offsets = tl.arange(0, BLOCK_SIZE_N) + col_mask = col_offsets < n_cols + row_offsets = tl.arange(0, BLOCK_SIZE_M) + + if elementwise_affine: + W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0) + + # Grid-stride loop over row blocks + for i in range(num_iterations): + row_idx = i * grid_stride + pid * BLOCK_SIZE_M + row_offsets + row_mask = row_idx < n_rows + block_mask = row_mask[:, None] & col_mask[None, :] + + # Load multiple rows at once using 2D indexing + X_rows = tl.load( + X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :], + mask=block_mask, + other=0.0, + eviction_policy="evict_first", + ) + + # Compute sum_square for all rows + if casting_mode == _CASTING_MODE_LLAMA or casting_mode == _CASTING_MODE_GEMMA: + X_rows = X_rows.to(tl.float32) + + sum_squares = tl.sum(tl.where(block_mask, X_rows * X_rows, 0.0), axis=1) + + # Compute rstd for all rows + mean_squares = sum_squares / n_cols + rstd_rows = rsqrt(mean_squares + eps) + + # Store rstd_rows + tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd_rows, mask=row_mask) + + # Apply casting based on mode + if casting_mode == _CASTING_MODE_GEMMA: + X_rows = X_rows.to(tl.float32) + if elementwise_affine: + W_row_fp32 = W_row.to(tl.float32) + elif casting_mode == _CASTING_MODE_LLAMA: + X_rows = X_rows.to(tl.float32) + + # Normalize + X_rows = X_rows * rstd_rows[:, None] + + # Cast back for Llama mode before weight multiplication + if casting_mode == _CASTING_MODE_LLAMA: + X_rows = X_rows.to(X_DTYPE) + + # Apply weight + if elementwise_affine: + if casting_mode == _CASTING_MODE_GEMMA: + Y_rows = X_rows * (offset + W_row_fp32[None, :]) + else: + Y_rows = X_rows * (offset + W_row[None, :]) + else: + Y_rows = X_rows + + # Cast back for Gemma mode + if casting_mode == _CASTING_MODE_GEMMA: + Y_rows = Y_rows.to(X_DTYPE) + + # Store results + tl.store(Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :], Y_rows, mask=block_mask) + + +# ----------------------------------------------------------------------------- +# Forward Kernel - With Tiling (for n_cols > 2048) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _rms_norm_forward_kernel_tiled( + Y_ptr, + Y_row_stride, + X_ptr, + X_row_stride, + W_ptr, + RSTD_ptr, + RSTD_row_stride, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + eps, + offset, + casting_mode: tl.constexpr, + elementwise_affine: tl.constexpr, + X_DTYPE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + NPU-optimized rms_norm forward kernel for large n_cols (>= 2048). + + This kernel processes rows using a grid-stride loop pattern: + 1. Each program handles multiple rows + 2. 
For each row, we process it in column chunks of BLOCK_SIZE + 3. Grid size is limited to NPU core count to avoid resource overflow + + This solves two problems: + 1. UB overflow when n_cols is too large (original kernel used n_cols as BLOCK_SIZE) + 2. Efficient multi-row processing within a single kernel launch + """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE) + + if casting_mode == _CASTING_MODE_NONE: + eps = eps.to(X_DTYPE) + offset = offset.to(X_DTYPE) + + offsets = tl.arange(0, BLOCK_SIZE) + # Grid-stride loop over rows + for row_idx in tl.range(pid, n_rows, num_progs): + Y_row_ptr = Y_ptr + row_idx * Y_row_stride + X_row_ptr = X_ptr + row_idx * X_row_stride + RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride + + # Accumulator for mean_square computation across all column blocks + sum_square = 0.0 + + # First pass: accumulate sum of squares + for col_block_idx in range(num_col_blocks): + col_start = col_block_idx * BLOCK_SIZE + col_offsets = col_start + offsets + mask = col_offsets < n_cols + + X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first") + + if casting_mode == _CASTING_MODE_LLAMA or casting_mode == _CASTING_MODE_GEMMA: + X_block = X_block.to(tl.float32) + + # Accumulate sum of squares (only for valid elements) + sum_square += tl.sum(X_block * X_block) + + # Compute rstd for this row + mean_square = sum_square / n_cols + + rstd = rsqrt(mean_square + eps) + + # Store rstd + tl.store(RSTD_row_ptr, rstd) + + # Second pass: normalize and multiply by weight + for col_block_idx in range(num_col_blocks): + col_start = col_block_idx * BLOCK_SIZE + col_offsets = col_start + offsets + mask = col_offsets < n_cols + + # Load X_block + X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".ca") + + if elementwise_affine: + W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0) + + # Apply casting based on mode + if casting_mode == _CASTING_MODE_GEMMA: + X_block = X_block.to(tl.float32) + if elementwise_affine: + W_block = W_block.to(tl.float32) + elif casting_mode == _CASTING_MODE_LLAMA: + X_block = X_block.to(tl.float32) + + # Normalize + X_block = X_block * rstd + + # Cast back for Llama mode before weight multiplication + if casting_mode == _CASTING_MODE_LLAMA: + X_block = X_block.to(X_DTYPE) + + # Apply weight + if elementwise_affine: + Y_block = X_block * (offset + W_block) + else: + Y_block = X_block + + # Cast back for Gemma mode + if casting_mode == _CASTING_MODE_GEMMA: + Y_block = Y_block.to(X_DTYPE) + + # Store result + tl.store(Y_row_ptr + col_offsets, Y_block, mask=mask) + + +# ----------------------------------------------------------------------------- +# Backward Kernel - No Tiling (for n_cols <= 2048) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _rms_norm_backward_kernel_no_tiling( + dY_ptr, + dY_row_stride, + dX_ptr, + dX_row_stride, + X_ptr, + X_row_stride, + X_dtype: tl.constexpr, + W_ptr, + RSTD_ptr, + RSTD_row_stride, + dW_ptr, + dW_row_stride, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + offset, + casting_mode: tl.constexpr, + elementwise_affine: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + """ + NPU-optimized rms_norm backward kernel for small n_cols (< 2048). + + Performance optimizations: + 1. Keep all data in registers, minimize conversions + 2. Reuse X_normalized (X * rstd) for both dX and dW + 3. 
Optimize computation order to reduce register pressure + 4. Combine operations where possible + 5. Use 2D vector loading to maximize UB utilization (e.g., (1,2048), (2,1024), (4,512)) + """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + # Grid-stride loop setup for 2D blocks + grid_stride = num_progs * BLOCK_SIZE_M + num_iterations = tl.cdiv(n_rows, grid_stride) + + col_offsets = tl.arange(0, BLOCK_SIZE_N) + col_mask = col_offsets < n_cols + row_offsets = tl.arange(0, BLOCK_SIZE_M) + + # Load W once for all iterations + if elementwise_affine: + W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0) + W_offset = W_row + offset + + # Grid-stride loop over row blocks + for i in range(num_iterations): + row_idx = i * grid_stride + pid * BLOCK_SIZE_M + row_offsets + row_mask = row_idx < n_rows + block_mask = row_mask[:, None] & col_mask[None, :] + + dY_rows = tl.load( + dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :], + mask=block_mask, + other=0.0, + eviction_policy="evict_first", + ) + X_rows = tl.load( + X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :], + mask=block_mask, + other=0.0, + eviction_policy="evict_first", + ) + + # Load rstd for all rows in the block + rstd_rows = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, mask=row_mask, other=0.0) + + # Convert X to fp32 once + X_rows = X_rows.to(tl.float32) + + # Compute X_normalized (reused in dX and dW) + X_normalized = X_rows * rstd_rows[:, None] + + # Compute m based on casting mode and elementwise_affine + if elementwise_affine: + if casting_mode == _CASTING_MODE_LLAMA: + m_rows = (dY_rows * W_offset[None, :]).to(tl.float32) + # For dW in Llama mode, we need X_normalized in original dtype + X_normalized = X_normalized.to(X_dtype) + elif casting_mode == _CASTING_MODE_GEMMA: + m_rows = dY_rows.to(tl.float32) * W_offset[None, :] + else: + m_rows = dY_rows * W_offset[None, :] + else: + if casting_mode == _CASTING_MODE_LLAMA or casting_mode == _CASTING_MODE_GEMMA: + m_rows = dY_rows.to(tl.float32) + else: + m_rows = dY_rows + + # Compute sum(m * X) for correction factor + sum_m_X = tl.sum(tl.where(block_mask, m_rows * X_rows, 0.0), axis=1) + + # Compute correction factor + correction_factors = -(1.0 / n_cols) * rstd_rows * rstd_rows * sum_m_X + + # Compute dX = rstd * m + rstd * correction_factor * X + dX_rows = rstd_rows[:, None] * m_rows + rstd_rows[:, None] * correction_factors[:, None] * X_rows + + # Store dX + tl.store(dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :], dX_rows.to(X_dtype), mask=block_mask) + + if elementwise_affine: + # Compute dW contribution: dY * X_normalized + dW_rows = (dY_rows * X_normalized).to(tl.float32) + + # Accumulate to per-program dW buffer + dW_row_ptr = dW_ptr + pid * dW_row_stride + existing_dW = tl.load(dW_row_ptr + col_offsets, mask=col_mask, other=0.0) + new_dW = existing_dW + tl.sum(tl.where(block_mask, dW_rows, 0.0), axis=0) + tl.store(dW_row_ptr + col_offsets, new_dW, mask=col_mask) + + +# ----------------------------------------------------------------------------- +# Backward Kernel - With Tiling (for n_cols > 2048) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _rms_norm_backward_kernel_tiled( + dY_ptr, + dY_row_stride, + dX_ptr, + dX_row_stride, + X_ptr, + X_row_stride, + X_dtype: tl.constexpr, + W_ptr, + RSTD_ptr, + RSTD_row_stride, + dW_ptr, + dW_row_stride, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + offset, + casting_mode: tl.constexpr, + elementwise_affine: 
tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + NPU-optimized rms_norm backward kernel for large n_cols (>= 2048). + + Each program processes multiple rows using grid-stride loop. + For each row, we process columns in blocks to avoid UB overflow. + """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + # Initialize dW accumulator (per-program, will be reduced later) + num_col_blocks = tl.cdiv(n_cols, BLOCK_SIZE) + offsets = tl.arange(0, BLOCK_SIZE) + + # Grid-stride loop over rows + for row_idx in tl.range(pid, n_rows, num_progs): + # Base pointers for this row + dY_row_ptr = dY_ptr + row_idx * dY_row_stride + dX_row_ptr = dX_ptr + row_idx * dX_row_stride + X_row_ptr = X_ptr + row_idx * X_row_stride + RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride + + # Load rstd for this row + rstd = tl.load(RSTD_row_ptr) + + # First pass: compute sum(m * X) for the correction term + sum_m_X = 0.0 + + for col_block_idx in range(num_col_blocks): + col_start = col_block_idx * BLOCK_SIZE + col_offsets = col_start + offsets + mask = col_offsets < n_cols + + dY_block = tl.load(dY_row_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first") + X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first") + + # Convert to fp32 for computation + X_block = X_block.to(tl.float32) + + if elementwise_affine: + W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first") + W_offset = W_block + offset + + # Compute m based on casting mode + if casting_mode == _CASTING_MODE_LLAMA: + m = (dY_block * W_offset).to(tl.float32) + elif casting_mode == _CASTING_MODE_GEMMA: + dY_block = dY_block.to(tl.float32) + m = dY_block * W_offset + else: + m = dY_block * W_offset + else: + # Compute m based on casting mode + if casting_mode == _CASTING_MODE_LLAMA: + m = dY_block.to(tl.float32) + elif casting_mode == _CASTING_MODE_GEMMA: + m = dY_block.to(tl.float32) + else: + m = dY_block + + # Accumulate sum(m * X) + sum_m_X += tl.sum(m * X_block) + + # Compute the correction factor + correction_factor = -(1.0 / n_cols) * rstd * rstd * sum_m_X + + # Second pass: compute gradients + for col_block_idx in range(num_col_blocks): + col_start = col_block_idx * BLOCK_SIZE + col_offsets = col_start + offsets + mask = col_offsets < n_cols + + dY_block = tl.load(dY_row_ptr + col_offsets, mask=mask, other=0.0) + X_block = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0) + + X_block = X_block.to(tl.float32) + + if elementwise_affine: + W_block = tl.load(W_ptr + col_offsets, mask=mask, other=0.0) + W_offset = W_block + offset + + # Compute m based on casting mode + if casting_mode == _CASTING_MODE_LLAMA: + m = (dY_block * W_offset).to(tl.float32) + elif casting_mode == _CASTING_MODE_GEMMA: + dY_block = dY_block.to(tl.float32) + m = dY_block * W_offset + else: + m = dY_block * W_offset + else: + # Compute m based on casting mode + if casting_mode == _CASTING_MODE_LLAMA: + m = dY_block.to(tl.float32) + elif casting_mode == _CASTING_MODE_GEMMA: + m = dY_block.to(tl.float32) + else: + m = dY_block + + # Compute dX + dX_block = rstd * m + rstd * correction_factor * X_block + + # Store dX + tl.store(dX_row_ptr + col_offsets, dX_block.to(X_dtype), mask=mask) + + if elementwise_affine: + # Compute dW contribution (accumulate per program) + if casting_mode == _CASTING_MODE_LLAMA: + dW_block = dY_block * (X_block * rstd).to(X_dtype) + else: + dW_block = dY_block * (X_block * rstd) + + # Atomic add to dW_ptr (each program writes to its own row) + 
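# (Despite the comment above, no hardware atomic is actually needed: each + # program accumulates into its own private row `pid` of the (grid_size, n_cols) + # dW buffer, so the plain load-add-store below is race-free; the per-program + # rows are reduced on the host via _dW.sum(dim=0).) + 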
dW_row_ptr = dW_ptr + pid * dW_row_stride + + # Load existing dW, add contribution, store back + existing_dW = tl.load(dW_row_ptr + col_offsets, mask=mask, other=0.0) + new_dW = existing_dW + dW_block.to(tl.float32) + tl.store(dW_row_ptr + col_offsets, new_dW, mask=mask) + + +# ----------------------------------------------------------------------------- +# Helper Functions +# ----------------------------------------------------------------------------- + + +def get_optimal_block_size(n_cols, is_forward: bool): + """ + Calculate optimal block size for forward pass using compute_default_tiling_strategy. + + Memory analysis for forward pass (per row): + - Load: X_block, W_block (2 blocks) + - Compute: X_block (fp32), Y_block (1-2 blocks) + - Total: conservative estimate 6 blocks of memory + + Memory analysis for backward pass (per row): + - Load: dY_block, X_block, W_block (3 blocks) + - Compute: m, dX_block, dW_block (3 blocks) + - Store: dX_block, accumulated dW (2 blocks) + - Total: conservative estimate 8 blocks of memory + + Args: + n_cols: Number of columns in the tensor + is_forward: Whether this is for forward pass + + Returns: + Optimal block size + """ + if n_cols <= 2048: + return triton.next_power_of_2(n_cols) + + memory_multiplier = 6.0 if is_forward else 8.0 + + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.9, + dtype_size=4, + memory_multiplier=memory_multiplier, + shapes=((n_cols,),), + tiling_dims=(0,), + ) + + if tile_shapes and len(tile_shapes) > 0: + block_size = tile_shapes[0][0] + return max(2048, block_size) + else: + return 2048 + + +# ----------------------------------------------------------------------------- +# Forward and Backward Functions +# ----------------------------------------------------------------------------- + + +_str_to_casting_mode = { + "llama": _CASTING_MODE_LLAMA.value, + "gemma": _CASTING_MODE_GEMMA.value, + "none": _CASTING_MODE_NONE.value, +} + + +def rms_norm_forward(X, W, eps, offset, casting_mode): + if not isinstance(casting_mode, int): + assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}" + casting_mode = _str_to_casting_mode[casting_mode] + else: + assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}" + shape = X.shape + dim = shape[-1] + X = X.view(-1, dim) + n_rows, n_cols = X.shape + X_DTYPE = torch_dtype_to_triton(X.dtype) + + # Get optimal block size for column processing + BLOCK_SIZE = get_optimal_block_size(n_cols, True) + BLOCK_SIZE_M = 2048 // BLOCK_SIZE + + Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + + # RSTD is always fp32 for Llama/Gemma modes + rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype + RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device) + + if W is not None: + # Check constraints + assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension" + elementwise_affine = True + else: + elementwise_affine = False + + # Grid size limited to NPU core count + num_cores = get_npu_core_count() + grid_size = min(num_cores * 2, n_rows) + + # Choose kernel based on n_cols + if n_cols <= 2048: + # Use no-tiling kernel for small n_cols + _rms_norm_forward_kernel_no_tiling[(grid_size,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + RSTD, + RSTD.stride(0), + n_rows, + n_cols, + eps, + offset, + casting_mode, + elementwise_affine, + X_DTYPE, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE, + ) + else: + # Use tiled kernel for large 
n_cols + _rms_norm_forward_kernel_tiled[(grid_size,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + RSTD, + RSTD.stride(0), + n_rows, + n_cols, + eps, + offset, + casting_mode, + elementwise_affine, + X_DTYPE, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return Y.view(*shape), X, RSTD, casting_mode + + +def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, in_place): + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + n_rows, n_cols = dY.shape + + # Get NPU core count for grid size + num_cores = get_npu_core_count() + grid_size = min(num_cores * 2, n_rows) + + # Get optimal block size for backward pass + BLOCK_SIZE = get_optimal_block_size(n_cols, False) + BLOCK_SIZE_M = 2048 // BLOCK_SIZE + + if W is not None: + # fp32 for numerical stability + _dW = torch.zeros((grid_size, n_cols), dtype=torch.float32, device=W.device) + elementwise_affine = True + else: + _dW = None + elementwise_affine = False + + if in_place: + dX = dY + else: + dX = torch.empty_like(dY) + + # Choose kernel based on n_cols + if n_cols <= 2048: + # Use no-tiling kernel for small n_cols + _rms_norm_backward_kernel_no_tiling[(grid_size,)]( + dY, + dY.stride(0), + dX, + dX.stride(0), + X, + X.stride(0), + torch_to_triton_dtype[X.dtype], + W, + RSTD, + RSTD.stride(0), + _dW, + _dW.stride(0) if elementwise_affine else 0, + n_rows, + n_cols, + offset, + casting_mode, + elementwise_affine, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE, + ) + else: + # Use tiled kernel for large n_cols + _rms_norm_backward_kernel_tiled[(grid_size,)]( + dY, + dY.stride(0), + dX, + dX.stride(0), + X, + X.stride(0), + torch_to_triton_dtype[X.dtype], + W, + RSTD, + RSTD.stride(0), + _dW, + _dW.stride(0) if elementwise_affine else 0, + n_rows, + n_cols, + offset, + casting_mode, + elementwise_affine, + BLOCK_SIZE=BLOCK_SIZE, + ) + + dX = dX.view(*shape) + + if elementwise_affine: + dW = _dW.sum(dim=0).to(W.dtype) + else: + dW = None + + return dX, dW + + +class LigerRMSNormFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None): + """ + X: (B, T, H) or (BxT, H) + W: (H,) + """ + if isinstance(X, torch.distributed.tensor.DTensor): + # Input tensor is output of a tensor parallel module and + # needs to be gathered to a local tensor to compute + # RMSE layer norm on each TP worker. + # TODO: support CP. + X = X.full_tensor() + + Y, X, RSTD, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode) + ctx.offset = offset + ctx.casting_mode = casting_mode + ctx.in_place = in_place + ctx.elementwise_affine = W is not None + if W is not None: + ctx.save_for_backward(X, W, RSTD) + else: + ctx.save_for_backward(X, RSTD) + return Y + + @staticmethod + @ensure_contiguous + def backward(ctx, dY): + """ + Y: (B, T, H) or (BxT, H) + """ + if ctx.elementwise_affine: + X, W, RSTD = ctx.saved_tensors + else: + X, RSTD = ctx.saved_tensors + W = None + if isinstance(dY, torch.distributed.tensor.DTensor): + # Gradients are output of a tensor parallel module and + # needs to be gathered to a local tensor for computing RMSE layer. + # TODO: support CP. 
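+ # (DTensor.full_tensor() performs the collectives needed to materialize the + # sharded gradient as a regular, replicated local tensor.)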
+ dY = dY.full_tensor() + + dX, dW = rms_norm_backward(dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.in_place) + return dX, dW, None, None, None, None, None diff --git a/src/liger_kernel/ops/backends/_ascend/ops/rope.py b/src/liger_kernel/ops/backends/_ascend/ops/rope.py new file mode 100755 index 0000000000000000000000000000000000000000..ba5391f8f654c77fd0288903efdf2ea4adbb82cb --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/rope.py @@ -0,0 +1,262 @@ +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import get_npu_core_count + + +@triton.jit +def _triton_rope_npu( + q_ptr, + q_row_stride, + k_ptr, + k_row_stride, + cos, + cos_row_stride, + sin, + sin_row_stride, + sl, + total_rows: tl.constexpr, + cos_bs: tl.constexpr, + n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + BLOCK_Q: tl.constexpr, + BLOCK_K: tl.constexpr, + BACKWARD_PASS: tl.constexpr = False, +): + program_id = tl.program_id(0) + num_programs = tl.num_programs(0) + + rows_per_program = (total_rows + num_programs - 1) // num_programs + start_row = program_id * rows_per_program + actual_rows = tl.minimum(rows_per_program, total_rows - start_row) + + for row_offset in tl.range(0, actual_rows): + pid = start_row + row_offset + + row_idx = pid % sl + cos_ptr = cos + tl.where(cos_bs == 1, row_idx * cos_row_stride, pid * cos_row_stride) + sin_ptr = sin + tl.where(cos_bs == 1, row_idx * sin_row_stride, pid * sin_row_stride) + + # Pre-compute d_idx and cos/sin values outside loops (they don't depend on heads) + d_idx = tl.arange(0, hd // 2) + d_mask = d_idx < (hd // 2) # Always True, but kept for clarity + cos_vals = tl.load(cos_ptr + d_idx, mask=d_mask, other=0) + sin_vals = tl.load(sin_ptr + d_idx, mask=d_mask, other=0) + + # Process q heads in chunks to prevent UB overflow + for qh_block in range(0, n_qh, BLOCK_Q): + qh_idx = tl.arange(0, BLOCK_Q) + qh_block + qh_mask = qh_idx < n_qh + + # block_mask: qh_mask broadcasted over d_idx dimension + block_mask = qh_mask[:, None] + + offsets = qh_idx[:, None] * hd + d_idx[None, :] + q_base = q_ptr + pid * q_row_stride + + q_left = tl.load(q_base + offsets, mask=block_mask, other=0) + q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0) + + if not BACKWARD_PASS: + new_left = q_left * cos_vals - q_right * sin_vals + new_right = q_right * cos_vals + q_left * sin_vals + else: + new_left = q_left * cos_vals + q_right * sin_vals + new_right = q_right * cos_vals - q_left * sin_vals + + tl.store(q_base + offsets, new_left, mask=block_mask) + tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask) + + # Process k heads in chunks to prevent UB overflow + for kh_block in range(0, n_kh, BLOCK_K): + kh_idx = tl.arange(0, BLOCK_K) + kh_block + kh_mask = kh_idx < n_kh + + # block_mask: kh_mask broadcasted over d_idx dimension + block_mask = kh_mask[:, None] + + offsets = kh_idx[:, None] * hd + d_idx[None, :] + k_base = k_ptr + pid * k_row_stride + + k_left = tl.load(k_base + offsets, mask=block_mask, other=0) + k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0) + + if not BACKWARD_PASS: + new_left = k_left * cos_vals - k_right * sin_vals + new_right = k_right * cos_vals + k_left * sin_vals + else: + new_left = k_left * cos_vals + k_right * sin_vals + new_right = k_right * cos_vals - k_left * sin_vals + + tl.store(k_base + offsets, new_left, mask=block_mask) + tl.store(k_base + offsets + (hd 
// 2), new_right, mask=block_mask) + + +def get_optimal_block_size(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size): + # Compute tiling strategy based on UB capacity + # ROPE forward tiling strategy (based on optimized ROPE kernel): + # - cos_vals and sin_vals are loaded once outside loops (shared): pad_hd // 2 elements each + # - In q heads loop (peak memory): + # * q_left: BLOCK_Q * (pad_hd // 2) elements + # * q_right: BLOCK_Q * (pad_hd // 2) elements + # * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result) + # * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result) + # * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements + # - In k heads loop (peak memory): + # * k_left: BLOCK_K * (pad_hd // 2) elements + # * k_right: BLOCK_K * (pad_hd // 2) elements + # * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result) + # * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result) + # * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements + # - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case + # - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements + # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits + # - Simplified: (2 * BLOCK_SIZE + 1) * pad_hd * dtype_size * 8 bits + # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits + # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd)) + # - tiling_dims: (0, 0) means first dimension of each shape can be tiled + # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd)) + shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd)) + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.90, + dtype_size=dtype_size, + memory_multiplier=3.0, + shapes=shapes, + tiling_dims=(0, 0), + ) + + if tile_shapes is not None and len(tile_shapes) == len(shapes): + # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd)) + q_tile_shape, k_tile_shape = tile_shapes + BLOCK_Q, _ = q_tile_shape + BLOCK_K, _ = k_tile_shape + else: + # Fallback to conservative defaults + BLOCK_Q = 2048 + BLOCK_K = 2048 + + return BLOCK_Q, BLOCK_K + + +def rope_forward(q, k, cos, sin): + # transpose it back to the physical shape because Triton looks at the physical storage + # note: q and k are incontiguous before the transformation and will become contiguous after transpose + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = q.shape + n_kv_head = k.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + + n_row = batch_size * seq_len + + # ensure tensors passed into the kernel are contiguous. 
It will be no-op if they are already contiguous + q = q.contiguous() + k = k.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + cos_batch_size = cos.shape[0] + + dtype_size = q.element_size() + BLOCK_Q, BLOCK_K = get_optimal_block_size(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size) + + num_cores = get_npu_core_count() + grid_size = min(num_cores, n_row) + + _triton_rope_npu[(grid_size,)]( + q, + q.stride(1), + k, + k.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + n_row, + cos_batch_size, + n_q_head, + n_kv_head, + head_dim, + BLOCK_Q, + BLOCK_K, + BACKWARD_PASS=False, + ) + return q.transpose(1, 2), k.transpose(1, 2), cos, sin + + +def rope_backward(dq, dk, cos, sin): + dq = dq.transpose(1, 2) + dk = dk.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = dq.shape + cos_batch_size = cos.shape[0] + n_kv_head = dk.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + + n_row = batch_size * seq_len + + # ensure dq and dk are contiguous + dq = dq.contiguous() + dk = dk.contiguous() + + dtype_size = dq.element_size() + BLOCK_Q, BLOCK_K = get_optimal_block_size(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size) + + num_cores = get_npu_core_count() + grid_size = min(num_cores, n_row) + + _triton_rope_npu[(grid_size,)]( + dq, + dq.stride(1), + dk, + dk.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + n_row, + cos_batch_size, + n_q_head, + n_kv_head, + head_dim, + BLOCK_Q, + BLOCK_K, + BACKWARD_PASS=True, + ) + return dq.transpose(1, 2), dk.transpose(1, 2) + + +class LigerRopeFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + q size: (bsz, n_q_head, seq_len, head_dim) + k size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + """ + q, k, cos, sin = rope_forward(q, k, cos, sin) + ctx.save_for_backward(cos, sin) + return q, k + + @staticmethod + def backward(ctx, dq, dk): + """ + dq size: (bsz, n_q_head, seq_len, head_dim) + dk size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + """ + + cos, sin = ctx.saved_tensors + dq, dk = rope_backward(dq, dk, cos, sin) + return dq, dk, None, None, None, None diff --git a/src/liger_kernel/ops/backends/_ascend/ops/softmax.py b/src/liger_kernel/ops/backends/_ascend/ops/softmax.py new file mode 100755 index 0000000000000000000000000000000000000000..d9144d5c8a87d60dc18b51ab59fe92a4ad42db84 --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/softmax.py @@ -0,0 +1,344 @@ +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count + + +@triton.jit +def _softmax_single_block_forward_kernel( + Y_ptr, + Y_row_stride, + X_ptr, + X_row_stride, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + ROWS_PER_BLOCK: tl.constexpr, +): + """ + Single-block softmax forward kernel for small column sizes. + + Processes entire row in one block when n_cols <= BLOCK_SIZE. + Uses 2D tensor to process multiple rows simultaneously for better UB utilization. 
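+ + For example, with n_cols = 1000 the launcher below picks BLOCK_SIZE = 1024 and + ROWS_PER_BLOCK = 8192 // 1024 = 8, so each program loads (8, 1024) tiles and + reduces max/sum along axis=1.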
+ + Args: + Y_ptr: Output tensor pointer + Y_row_stride: Stride for output rows + X_ptr: Input tensor pointer + X_row_stride: Stride for input rows + n_rows: Number of rows to process + n_cols: Number of columns per row + BLOCK_SIZE: Block size for column processing + ROWS_PER_BLOCK: Number of rows to process simultaneously + """ + row_block_start = tl.program_id(0) * ROWS_PER_BLOCK + row_block_step = tl.num_programs(0) * ROWS_PER_BLOCK + + row_offsets = tl.arange(0, ROWS_PER_BLOCK) + col_offsets = tl.arange(0, BLOCK_SIZE) + + for row_block_idx in tl.range(row_block_start, n_rows, row_block_step): + row_idx = row_block_idx + row_offsets + row_mask = row_idx < n_rows + col_mask = col_offsets < n_cols + + # 2D mask: [ROWS_PER_BLOCK, BLOCK_SIZE] + mask = row_mask[:, None] & col_mask[None, :] + + # Load 2D block: [ROWS_PER_BLOCK, BLOCK_SIZE] + offsets = row_idx[:, None] * X_row_stride + col_offsets[None, :] + x = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")) + + # Compute softmax per row (axis=1) + m = tl.max(x, axis=1) + e = tl.exp(x - m[:, None]) + d = tl.sum(e, axis=1) + y = e / d[:, None] + + # Store 2D block + offsets = row_idx[:, None] * Y_row_stride + col_offsets[None, :] + tl.store(Y_ptr + offsets, y, mask=mask) + + +@triton.jit +def _softmax_multi_block_forward_kernel( + Y_ptr, + Y_row_stride, + X_ptr, + X_row_stride, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + Multi-block softmax forward kernel using two-pass algorithm. + + First pass computes max and sum for numerical stability. + Second pass normalizes and writes output. + + Args: + Y_ptr: Output tensor pointer + Y_row_stride: Stride for output rows + X_ptr: Input tensor pointer + X_row_stride: Stride for input rows + n_rows: Number of rows to process + n_cols: Number of columns per row + BLOCK_SIZE: Block size for column processing + """ + row_start = tl.program_id(0) + num_prog = tl.num_programs(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + + for row_idx in tl.range(row_start, n_rows, num_prog): + row_start_ptr = X_ptr + row_idx * X_row_stride + m = tl.float32(float("-inf")) + d = tl.float32(0.0) + + for start in tl.range(0, n_cols, BLOCK_SIZE): + idx = start + col_offsets + mask = idx < n_cols + xblk = tl.load( + row_start_ptr + idx, mask=mask, other=float("-inf"), eviction_policy="evict_first", cache_modifier=".ca" + ) + blk_max = tl.max(xblk, axis=0) + new_m = tl.maximum(m, blk_max) + d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0) + m = new_m + + for start in tl.range(0, n_cols, BLOCK_SIZE): + idx = start + col_offsets + mask = idx < n_cols + xblk = tl.load( + row_start_ptr + idx, mask=mask, other=float("-inf"), eviction_policy="evict_first", cache_modifier=".ca" + ) + yblk = tl.exp(xblk - m) / d + tl.store(Y_ptr + row_idx * Y_row_stride + idx, yblk, mask=mask, cache_modifier=".cs") + + +@triton.jit +def _softmax_single_block_backward_kernel( + dy_ptr, + dy_stride, + y_ptr, + y_stride, + dx_ptr, + dx_stride, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + ROWS_PER_BLOCK: tl.constexpr, +): + """ + Single-block softmax backward kernel for small column sizes. + + Computes gradient: dx = y * (dy - sum(dy * y)) + Uses 2D tensor to process multiple rows simultaneously for better UB utilization. 
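+ + Per row this is the softmax Jacobian-vector product: with J = diag(y) - y y^T, + dx = J^T dy = y * (dy - sum(dy * y)).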
+ + Args: + dy_ptr: Gradient output pointer + dy_stride: Stride for gradient output rows + y_ptr: Forward output pointer + y_stride: Stride for forward output rows + dx_ptr: Gradient input pointer + dx_stride: Stride for gradient input rows + n_rows: Number of rows to process + n_cols: Number of columns per row + BLOCK_SIZE: Block size for column processing + ROWS_PER_BLOCK: Number of rows to process simultaneously + """ + row_block_start = tl.program_id(0) * ROWS_PER_BLOCK + row_block_step = tl.num_programs(0) * ROWS_PER_BLOCK + + row_offsets = tl.arange(0, ROWS_PER_BLOCK) + col_offsets = tl.arange(0, BLOCK_SIZE) + + for row_block_idx in tl.range(row_block_start, n_rows, row_block_step): + row_idx = row_block_idx + row_offsets + row_mask = row_idx < n_rows + col_mask = col_offsets < n_cols + + # 2D mask: [ROWS_PER_BLOCK, BLOCK_SIZE] + mask = row_mask[:, None] & col_mask[None, :] + + # Load 2D blocks: [ROWS_PER_BLOCK, BLOCK_SIZE] + dy_offsets = row_idx[:, None] * dy_stride + col_offsets[None, :] + y_offsets = row_idx[:, None] * y_stride + col_offsets[None, :] + + dy = tl.load(dy_ptr + dy_offsets, mask=mask, other=0.0) + y = tl.load(y_ptr + y_offsets, mask=mask, other=0.0) + + # Compute dot product per row (axis=1) + dot = tl.sum(dy * y, axis=1) + dx = y * (dy - dot[:, None]) + + # Store 2D block + dx_offsets = row_idx[:, None] * dx_stride + col_offsets[None, :] + tl.store(dx_ptr + dx_offsets, dx, mask=mask) + + +@triton.jit +def _softmax_multi_block_backward_kernel( + dy_ptr, + dy_stride, + y_ptr, + y_stride, + dx_ptr, + dx_stride, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + Multi-block softmax backward kernel using two-pass algorithm. + + Computes gradient: dx = y * (dy - sum(dy * y)) + + Args: + dy_ptr: Gradient output pointer + dy_stride: Stride for gradient output rows + y_ptr: Forward output pointer + y_stride: Stride for forward output rows + dx_ptr: Gradient input pointer + dx_stride: Stride for gradient input rows + n_rows: Number of rows to process + n_cols: Number of columns per row + BLOCK_SIZE: Block size for column processing + """ + row_start = tl.program_id(0) + num_prog = tl.num_programs(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + + for row_idx in tl.range(row_start, n_rows, num_prog): + dy_start_ptr = dy_ptr + row_idx * dy_stride + y_start_ptr = y_ptr + row_idx * y_stride + acc = 0.0 + + for start in tl.range(0, n_cols, BLOCK_SIZE): + idx = start + col_offsets + mask = idx < n_cols + dy_blk = tl.load(dy_start_ptr + idx, mask=mask, other=0.0, eviction_policy="evict_first") + y_blk = tl.load( + y_start_ptr + idx, mask=mask, other=0.0, eviction_policy="evict_first", cache_modifier=".ca" + ) + acc += tl.sum(dy_blk * y_blk, axis=0) + + for start in tl.range(0, n_cols, BLOCK_SIZE): + idx = start + col_offsets + mask = idx < n_cols + dy_blk = tl.load(dy_start_ptr + idx, mask=mask, other=0.0) + y_blk = tl.load(y_start_ptr + idx, mask=mask, other=0.0, cache_modifier=".ca") + dx_blk = y_blk * (dy_blk - acc) + tl.store(dx_ptr + row_idx * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb") + + +def _softmax_forward(x): + *batch, n_cols = x.shape + x2d = x.contiguous().view(-1, n_cols) + n_rows = x2d.shape[0] + MAX_FUSED_BLOCK_SIZE = 8192 + + BLOCK_SIZE = triton.next_power_of_2(n_cols) + BLOCK_SIZE = min(BLOCK_SIZE, MAX_FUSED_BLOCK_SIZE) + + y2d = torch.empty_like(x2d) + num_cores = get_npu_core_count() + + if n_cols <= BLOCK_SIZE: + # Calculate optimal ROWS_PER_BLOCK to utilize UB efficiently + # Target: ROWS_PER_BLOCK * BLOCK_SIZE 
<= MAX_FUSED_BLOCK_SIZE + ROWS_PER_BLOCK = min(MAX_FUSED_BLOCK_SIZE // BLOCK_SIZE, 32) + ROWS_PER_BLOCK = triton.next_power_of_2(ROWS_PER_BLOCK) + + # Calculate number of programs needed + num_row_blocks = (n_rows + ROWS_PER_BLOCK - 1) // ROWS_PER_BLOCK + num_programs = min(num_cores, num_row_blocks) + + _softmax_single_block_forward_kernel[(num_programs,)]( + y2d, y2d.stride(0), x2d, x2d.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, ROWS_PER_BLOCK=ROWS_PER_BLOCK + ) + multi_block_launch = False + else: + num_programs = min(num_cores, n_rows) + ROWS_PER_BLOCK = 1 # Not used in multi-block + + _softmax_multi_block_forward_kernel[(num_programs,)]( + y2d, y2d.stride(0), x2d, x2d.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE + ) + multi_block_launch = True + + return y2d.view(*batch, n_cols), BLOCK_SIZE, ROWS_PER_BLOCK, multi_block_launch + + +def _softmax_backward( + dy: torch.Tensor, + y: torch.Tensor, + BLOCK_SIZE: int, + ROWS_PER_BLOCK: int, + multi_block_launch: bool, +) -> torch.Tensor: + *batch, n_cols = dy.shape + dy2d = dy.contiguous().view(-1, n_cols) + y2d = y.contiguous().view(-1, n_cols) + n_rows = dy2d.shape[0] + dx2d = torch.empty_like(dy2d) + + num_cores = get_npu_core_count() + + if not multi_block_launch and n_cols <= BLOCK_SIZE: + num_row_blocks = (n_rows + ROWS_PER_BLOCK - 1) // ROWS_PER_BLOCK + num_programs = min(num_cores, num_row_blocks) + _softmax_single_block_backward_kernel[(num_programs,)]( + dy2d, + dy2d.stride(0), + y2d, + y2d.stride(0), + dx2d, + dx2d.stride(0), + n_rows, + n_cols, + BLOCK_SIZE=BLOCK_SIZE, + ROWS_PER_BLOCK=ROWS_PER_BLOCK, + ) + else: + num_programs = min(num_cores, n_rows) + + _softmax_multi_block_backward_kernel[(num_programs,)]( + dy2d, + dy2d.stride(0), + y2d, + y2d.stride(0), + dx2d, + dx2d.stride(0), + n_rows, + n_cols, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return dx2d.view(*batch, n_cols) + + +class LigerSoftmaxFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, input_: torch.Tensor): + y, BLOCK_SIZE, ROWS_PER_BLOCK, multi_block_launch = _softmax_forward(input_) + ctx.save_for_backward(y) + ctx.BLOCK_SIZE = BLOCK_SIZE + ctx.ROWS_PER_BLOCK = ROWS_PER_BLOCK + ctx.multi_block_launch = multi_block_launch + return y + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output): + (y,) = ctx.saved_tensors + dx = _softmax_backward( + grad_output, + y, + ctx.BLOCK_SIZE, + ctx.ROWS_PER_BLOCK, + ctx.multi_block_launch, + ) + return dx diff --git a/src/liger_kernel/ops/backends/_ascend/ops/sparsemax.py b/src/liger_kernel/ops/backends/_ascend/ops/sparsemax.py new file mode 100755 index 0000000000000000000000000000000000000000..a6deaf8af89cfc62458e722ab61ba56bcb76562b --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/sparsemax.py @@ -0,0 +1,385 @@ +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count + + +@triton.jit +def _sparsemax_forward_kernel( + x_ptr, + x_stride_row, + sorted_x_ptr, + sorted_x_stride_row, + o_ptr, + o_stride_row, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Sparsemax forward kernel for rows where n_cols <= BLOCK_SIZE. + + Args: + x_ptr: pointer to input tensor [n_rows, n_cols], fp32. + x_stride_row: row stride of x. + sorted_x_ptr: pointer to x sorted descending along last dim, fp32. 
+ sorted_x_stride_row: row stride of sorted_x. + o_ptr: pointer to output tensor [n_rows, n_cols]. + o_stride_row: row stride of o. + n_rows: number of rows (constexpr). + n_cols: number of columns (constexpr). + BLOCK_SIZE: tile size >= n_cols (constexpr). + """ + pid_row = tl.program_id(0) + num_progs = tl.num_programs(0) + + for row in tl.range(pid_row, n_rows, num_progs): + ptr_x_data_row = x_ptr + row * x_stride_row + ptr_sorted_x_data_row = sorted_x_ptr + row * sorted_x_stride_row + ptr_output_row = o_ptr + row * o_stride_row + + offs = tl.arange(0, BLOCK_SIZE) + mask = offs < n_cols + + z_sorted_block = tl.load( + ptr_sorted_x_data_row + offs, + mask=mask, + other=-float("inf"), + cache_modifier=".cg", + ).to(tl.float32) + + z_valid = tl.where(mask, z_sorted_block, 0.0) + cssv = tl.cumsum(z_valid, 0) + + r = (offs + 1).to(tl.float32) + t_vec = (cssv - 1.0) / r + support = (z_sorted_block > t_vec) & mask + + k_int = tl.sum(support.to(tl.int32), 0) + k_clamped_int = tl.maximum(k_int, 1) + k = k_clamped_int.to(tl.float32) + + s = tl.sum(tl.where(support, z_sorted_block, 0.0), 0) + tau = (s - 1.0) / k + + x_block = tl.load( + ptr_x_data_row + offs, + mask=mask, + other=0.0, + cache_modifier=".cg", + ).to(tl.float32) + + y = tl.maximum(x_block - tau, 0.0) + + tl.store( + ptr_output_row + offs, + y.to(ptr_output_row.dtype.element_ty), + mask=mask, + cache_modifier=".cs", + ) + + +@triton.jit +def _sparsemax_forward_tiled_kernel( + x_ptr, + x_stride_row, + sorted_x_ptr, + sorted_x_stride_row, + o_ptr, + o_stride_row, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Sparsemax forward kernel for rows where n_cols > BLOCK_SIZE (tiled). + + Args: + x_ptr: pointer to input tensor [n_rows, n_cols], fp32. + x_stride_row: row stride of x. + sorted_x_ptr: pointer to x sorted descending along last dim, fp32. + sorted_x_stride_row: row stride of sorted_x. + o_ptr: pointer to output tensor [n_rows, n_cols]. + o_stride_row: row stride of o. + n_rows: number of rows (constexpr). + n_cols: number of columns (constexpr). + BLOCK_SIZE: tile size < n_cols (constexpr). + """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + for row in tl.range(pid, n_rows, num_progs): + sorted_row_ptr = sorted_x_ptr + row * sorted_x_stride_row + x_row_ptr = x_ptr + row * x_stride_row + out_row_ptr = o_ptr + row * o_stride_row + offs = tl.arange(0, BLOCK_SIZE) + + # ------------------------------------------------------------------ + # Pass 1: find tau from sorted data + # Since data is sorted descending, support is a contiguous prefix, + # so k = sum(support) — no need for max(support_r), saves one reduction. 
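+ # Concretely (Martins & Astudillo, 2016): with z sorted descending and + # cssv_j = sum_{i <= j} z_(i), the support size is + # k = max{ j : z_(j) > (cssv_j - 1) / j } and tau = (cssv_k - 1) / k.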
+ # ------------------------------------------------------------------ + running_sum = tl.zeros((), tl.float32) + k = tl.zeros((), tl.int32) + sum_support = tl.zeros((), tl.float32) + + for tile in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)): + idx = tile * BLOCK_SIZE + offs + mask = idx < n_cols + + z = tl.load(sorted_row_ptr + idx, mask=mask, other=0.0, cache_modifier=".ca").to(tl.float32) + + cssv = tl.cumsum(z, axis=0) + running_sum + r = (idx + 1).to(tl.float32) + t = (cssv - 1.0) / r + support = (z > t) & mask + + k += tl.sum(support.to(tl.int32), axis=0) + sum_support += tl.sum(tl.where(support, z, 0.0), axis=0) + running_sum += tl.sum(z, axis=0) + + tau = (sum_support - 1.0) / tl.maximum(k, 1).to(tl.float32) + + # ------------------------------------------------------------------ + # Pass 2: write output y = max(x - tau, 0) + # ------------------------------------------------------------------ + for tile in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)): + idx = tile * BLOCK_SIZE + offs + mask = idx < n_cols + + x = tl.load(x_row_ptr + idx, mask=mask, other=0.0, cache_modifier=".ca").to(tl.float32) + y = tl.maximum(x - tau, 0.0) + + tl.store(out_row_ptr + idx, y.to(out_row_ptr.dtype.element_ty), mask=mask, cache_modifier=".cs") + + +@triton.jit +def _sparsemax_backward_kernel( + o_ptr, + go_ptr, + gi_ptr, + stride, + n_rows: tl.constexpr, + n_cols: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Sparsemax backward kernel for rows where n_cols <= BLOCK_SIZE. + + Args: + o_ptr: pointer to forward output [n_rows, n_cols], fp32. + go_ptr: pointer to upstream gradient [n_rows, n_cols]. + gi_ptr: pointer to input gradient output [n_rows, n_cols]. + stride: common row stride for o, go, gi. + n_rows: number of rows (constexpr). + n_cols: number of columns (constexpr). + BLOCK_SIZE: tile size >= n_cols (constexpr). + """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + for row in tl.range(pid, n_rows, num_progs): + o_row = o_ptr + row * stride + go_row = go_ptr + row * stride + gi_row = gi_ptr + row * stride + + offs = tl.arange(0, BLOCK_SIZE) + mask = offs < n_cols + + o_val = tl.load(o_row + offs, mask=mask, other=0.0).to(tl.float32) + go_val = tl.load(go_row + offs, mask=mask, other=0.0).to(tl.float32) + supp = (o_val > 0.0) & mask + + go_sum = tl.sum(tl.where(supp, go_val, 0.0), axis=0) + supp_cnt = tl.sum(supp.to(tl.float32), axis=0) + + gi_val = tl.where( + supp, + go_val - go_sum / tl.maximum(supp_cnt, 1.0), + 0.0, + ) + tl.store(gi_row + offs, gi_val.to(gi_row.dtype.element_ty), mask=mask) + + +@triton.jit +def _sparsemax_backward_tiled_kernel( + o_ptr, go_ptr, gi_ptr, stride, n_rows: tl.constexpr, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr +): + """Sparsemax backward kernel for rows where n_cols > BLOCK_SIZE (tiled). + + Args: + o_ptr: pointer to forward output [n_rows, n_cols], fp32. + go_ptr: pointer to upstream gradient [n_rows, n_cols]. + gi_ptr: pointer to input gradient output [n_rows, n_cols]. + stride: common row stride for o, go, gi. + n_rows: number of rows (constexpr). + n_cols: number of columns (constexpr). + BLOCK_SIZE: tile size < n_cols (constexpr). 
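+ + Gradient: with support S = {i : o_i > 0}, gi_i = go_i - (1/|S|) * sum_{j in S} go_j + for i in S, and gi_i = 0 elsewhere.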
+ """ + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + for row in tl.range(pid, n_rows, num_progs): + o_row = o_ptr + row * stride + go_row = go_ptr + row * stride + gi_row = gi_ptr + row * stride + + offs = tl.arange(0, BLOCK_SIZE) + + supp_cnt = tl.zeros((), tl.float32) + go_sum = tl.zeros((), tl.float32) + + for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)): + offs_iter = i * BLOCK_SIZE + offs + mask_iter = offs_iter < n_cols + o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32) + go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32) + supp = o_val > 0 + go_sum += tl.sum(tl.where(supp, go_val, 0.0)) + supp_cnt += tl.sum(supp.to(tl.float32)) + + for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)): + offs_iter = i * BLOCK_SIZE + offs + mask_iter = offs_iter < n_cols + o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32) + go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32) + + supp = o_val > 0 + gi_val = tl.where( + supp, + go_val - tl.cast(go_sum / tl.maximum(supp_cnt, 1e-6), gi_row.dtype.element_ty).to(tl.float32), + 0.0, + ) + tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".cs") + + +def sparsemax_forward(x, dim): + if dim < 0: + dim += x.dim() + + x_sw = x.transpose(dim, -1).contiguous() + n_cols = x_sw.size(-1) + n_rows = x_sw.numel() // n_cols + x_flat = x_sw.view(n_rows, n_cols) + + x_flat_fp32 = x_flat if x_flat.dtype == torch.float32 else x_flat.float() + x_sorted_flat = torch.sort(x_flat_fp32, dim=-1, descending=True).values + + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.9, + dtype_size=4, + memory_multiplier=12.0, + shapes=((n_cols,),), + tiling_dims=(0,), + ) + + if tile_shapes and len(tile_shapes) > 0: + BLOCK_SIZE = tile_shapes[0][0] + else: + BLOCK_SIZE = 2048 + + out_flat = torch.empty_like(x_flat_fp32) + grid = (min(n_rows, get_npu_core_count()),) + + if n_cols <= BLOCK_SIZE: + # non-tiled kernel: single load covers whole row + _sparsemax_forward_kernel[grid]( + x_flat_fp32, + x_flat_fp32.stride(0), + x_sorted_flat, + x_sorted_flat.stride(0), + out_flat, + out_flat.stride(0), + n_rows, + n_cols, + BLOCK_SIZE=BLOCK_SIZE, + ) + else: + # tiled kernel: compute tau and write output in one fused kernel + _sparsemax_forward_tiled_kernel[grid]( + x_flat_fp32, + x_flat_fp32.stride(0), + x_sorted_flat, + x_sorted_flat.stride(0), + out_flat, + out_flat.stride(0), + n_rows, + n_cols, + BLOCK_SIZE=BLOCK_SIZE, + ) + + y = out_flat.view(x_sw.shape).transpose(dim, -1) + return y, out_flat + + +def sparsemax_backward( + grad_out: torch.Tensor, + out_flat: torch.Tensor, + dim: int, +) -> torch.Tensor: + if dim < 0: + dim += grad_out.dim() + + grad_sw = grad_out.transpose(dim, -1).contiguous() + n_cols = grad_sw.size(-1) + n_rows = grad_sw.numel() // n_cols + go_flat = grad_sw.view(n_rows, n_cols) + + dx_flat = torch.empty_like(go_flat).contiguous() + grid = (min(n_rows, get_npu_core_count()),) + + # use single-pass kernel when feasible + if n_cols <= 4096: + BLOCK_SIZE = triton.next_power_of_2(n_cols) + _sparsemax_backward_kernel[grid]( + out_flat, + go_flat, + dx_flat, + out_flat.stride(0), + n_rows, + n_cols, + BLOCK_SIZE=BLOCK_SIZE, + ) + + else: + # use tiling strategy for very large n_cols: ~10 live buffers at peak = 10.0 multiplier + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.9, + dtype_size=4, + memory_multiplier=8.0, + 
shapes=((n_cols,),), + tiling_dims=(0,), + ) + + if tile_shapes and len(tile_shapes) > 0: + BLOCK_SIZE = tile_shapes[0][0] + else: + BLOCK_SIZE = 2048 + + _sparsemax_backward_tiled_kernel[grid]( + out_flat, + go_flat, + dx_flat, + out_flat.stride(0), + n_rows, + n_cols, + BLOCK_SIZE=BLOCK_SIZE, + ) + + dx = dx_flat.view_as(grad_sw).transpose(dim, -1) + return dx + + +class LigerSparsemaxFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, x: torch.Tensor, dim: int): + y, out_flat = sparsemax_forward(x, dim) + ctx.save_for_backward(out_flat) + ctx.dim = dim + return y + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_out: torch.Tensor): + (out_flat,) = ctx.saved_tensors + dx = sparsemax_backward(grad_out, out_flat, ctx.dim) + return dx, None diff --git a/src/liger_kernel/ops/backends/_ascend/ops/swiglu.py b/src/liger_kernel/ops/backends/_ascend/ops/swiglu.py new file mode 100755 index 0000000000000000000000000000000000000000..9c244742e0c375d9eb08c424b75b244bbdd7771e --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/swiglu.py @@ -0,0 +1,136 @@ +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import get_npu_core_count + +# ----------------------------------------------------------------------------- +# Kernels (High-performance 1D Flatten Implementation) +# ----------------------------------------------------------------------------- + + +@triton.jit +def _swiglu_forward_kernel_flat(a_ptr, b_ptr, c_ptr, total_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + + # Grid-Stride Loop + start_idx = pid * BLOCK_SIZE + stride = num_progs * BLOCK_SIZE + + for idx in tl.range(start_idx, total_elements, stride): + offsets = idx + tl.arange(0, BLOCK_SIZE) + mask = offsets < total_elements + + a_val = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + b_val = tl.load(b_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + res = (a_val * tl.sigmoid(a_val)) * b_val + tl.store(c_ptr + offsets, res, mask=mask) + + +@triton.jit +def _swiglu_backward_kernel_flat(dc_ptr, a_ptr, b_ptr, da_ptr, db_ptr, total_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + num_progs = tl.num_programs(0) + start_idx = pid * BLOCK_SIZE + stride = num_progs * BLOCK_SIZE + + for idx in tl.range(start_idx, total_elements, stride): + offsets = idx + tl.arange(0, BLOCK_SIZE) + mask = offsets < total_elements + + dc = tl.load(dc_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + a = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + b = tl.load(b_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + + sig_a = tl.sigmoid(a) + silu_a = a * sig_a + term1 = silu_a * (1.0 - sig_a) + sig_a + + db = dc * silu_a + da = dc * b * term1 + + tl.store(da_ptr + offsets, da, mask=mask) + tl.store(db_ptr + offsets, db, mask=mask) + + +# ----------------------------------------------------------------------------- +# Helper: Call compute_default_tiling_strategy +# ----------------------------------------------------------------------------- + + +def get_optimal_block_size(total_elements, is_backward=False): + """ + Calculate optimal Block Size using compute_default_tiling_strategy + """ + # 1. 
Set Memory Multiplier + # Forward is lighter, Backward requires more memory for intermediate variables + # 8.0 and 12.0 are empirical values based on Atlas 800I A2 UB (192KB) + multiplier = 12.0 if is_backward else 8.0 + + # 2. Call calculation function + # Treat input as 1D (total_elements,), only tiling on dim 0 + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((total_elements,),), tiling_dims=(0,) + ) + + # 3. Parse result + if tile_shapes and len(tile_shapes) > 0: + block_size = tile_shapes[0][0] + return max(256, block_size) + else: + return 2048 + + +def swiglu_forward(a, b): + if not a.is_contiguous(): + a = a.contiguous() + if not b.is_contiguous(): + b = b.contiguous() + + total_elements = a.numel() + c = torch.empty_like(a) + + block_size = get_optimal_block_size(total_elements, is_backward=False) + + num_cores = get_npu_core_count() + grid_size = min(num_cores, (total_elements + block_size - 1) // block_size) + + _swiglu_forward_kernel_flat[(grid_size,)](a, b, c, total_elements, BLOCK_SIZE=block_size) + return c + + +def swiglu_backward(a, b, dc): + if not dc.is_contiguous(): + dc = dc.contiguous() + if not a.is_contiguous(): + a = a.contiguous() + if not b.is_contiguous(): + b = b.contiguous() + + total_elements = dc.numel() + grad_a = torch.empty_like(a) + grad_b = torch.empty_like(b) + + block_size = get_optimal_block_size(total_elements, is_backward=True) + + num_cores = get_npu_core_count() + grid_size = min(num_cores, (total_elements + block_size - 1) // block_size) + + _swiglu_backward_kernel_flat[(grid_size,)](dc, a, b, grad_a, grad_b, total_elements, BLOCK_SIZE=block_size) + return grad_a, grad_b + + +class LigerSiLUMulFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, a, b): + c = swiglu_forward(a, b) + ctx.save_for_backward(a, b) + return c + + @staticmethod + def backward(ctx, dc): + a, b = ctx.saved_tensors + grad_a, grad_b = swiglu_backward(a, b, dc) + return grad_a, grad_b diff --git a/src/liger_kernel/ops/backends/_ascend/ops/tvd.py b/src/liger_kernel/ops/backends/_ascend/ops/tvd.py new file mode 100755 index 0000000000000000000000000000000000000000..62a913a09a8c9532bf7523896cbfb2cee773f288 --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ops/tvd.py @@ -0,0 +1,221 @@ +from typing import Literal +from typing import Optional + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count + +MAX_FUSED_SIZE = 65536 // 4 + +REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"] + + +@triton.jit +def _tv_distance_kernel( + p_ptr, + p_stride, + q_ptr, + q_stride, + loss_ptr, + loss_stride, + grads_ptr, + grads_stride, + label_ptr, + ignore_index: tl.constexpr, + n_cols, # V + total_rows: tl.constexpr, # BT + BLOCK_SIZE: tl.constexpr, + HAS_LABEL: tl.constexpr, + reduction: tl.constexpr = "batchmean", +): + thread_id = tl.program_id(0) + num_threads = tl.num_programs(0) + + for pid in tl.range(thread_id, total_rows, num_threads): + p_row_ptr = p_ptr + pid * p_stride + q_row_ptr = q_ptr + pid * q_stride + loss_row_ptr = loss_ptr + pid * loss_stride + grads_row_ptr = grads_ptr + pid * grads_stride + label_row_ptr = label_ptr + pid + + base_offsets = tl.arange(0, BLOCK_SIZE) + + should_skip = False + if HAS_LABEL: + label = tl.load(label_row_ptr) + if label == 
ignore_index: + should_skip = True + + if should_skip: + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + base_offsets + mask = offsets < n_cols + tl.store(grads_row_ptr + offsets, 0.0, mask=mask) + if reduction == "none": + tl.store(loss_row_ptr + offsets, 0.0, mask=mask) + else: + loss_sum = 0.0 + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + base_offsets + mask = offsets < n_cols + + p = tl.load(p_row_ptr + offsets, mask=mask, other=0.0) + q = tl.load(q_row_ptr + offsets, mask=mask, other=0.0) + + # TVD(P || Q) = 0.5 * |P - Q| + tv_loss = 0.5 * tl.abs(p - q) + grad_res = tl.where(p > q, 0.5, -0.5) + + tl.store(grads_row_ptr + offsets, grad_res, mask=mask) + + if reduction == "none": + tl.store(loss_row_ptr + offsets, tv_loss, mask=mask) + else: + loss_sum += tl.sum(tv_loss, axis=0) + + if reduction != "none": + tl.store(loss_row_ptr, loss_sum) + + + def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label): + BT, V = p.shape + + # TVD forward tiling strategy + # - In the main loop (loss and gradient computation), the live buffers are: + # * p: BLOCK_SIZE elements + # * q: BLOCK_SIZE elements + # * tv_loss: BLOCK_SIZE elements + # * grad_res: BLOCK_SIZE elements + # * loss_sum: BLOCK_SIZE elements (when reduction != "none") + # * Total: 4 * BLOCK_SIZE elements, or 5 * BLOCK_SIZE when reduction != "none" + # - loss_sum is not always live, but budgeting for it also covers other shared + # buffers and the potential memory use of the HAS_LABEL path. + # - Conservative estimate: 5 * BLOCK_SIZE * dtype_size * 8 bits, hence memory_multiplier=5.0 + # - shapes: ((V,),) + # - tiling_dims: (0,) means the first dimension of the shape can be tiled + # - Returns: ((block_size,),) + shapes = ((V,),) + tile_shapes = compute_default_tiling_strategy( + safety_margin=0.80, + # In the TVD computation most values are implicitly promoted to fp32, so the fp32 size is used directly. + dtype_size=4, + memory_multiplier=5.0, + shapes=shapes, + tiling_dims=(0,), + ) + + if tile_shapes is not None and len(tile_shapes) > 0 and len(tile_shapes[0]) > 0: + # Strategy returns ((block_size,),) + BLOCK_SIZE = tile_shapes[0][0] + else: + # Fall back to the desired block size if no strategy is available (no tiling needed) + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + num_cores = get_npu_core_count() + grid = (min(num_cores, BT),) + + out_size = (BT, V) if reduction == "none" else (BT,) + + # Loss and gradient accumulation in BF16 on NPU loses precision, so accumulate in fp32. + output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32) + grads = torch.empty_like(p, dtype=torch.float32) + + n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT + + _tv_distance_kernel[grid]( + p, + p.stride(0), + q, + q.stride(0), + output_tensor, + output_tensor.stride(0), + grads, + grads.stride(0), + shift_labels if has_label else torch.empty(1, device=p.device), + ignore_index, + V, + BT, + BLOCK_SIZE=BLOCK_SIZE, + HAS_LABEL=has_label, + reduction=reduction, + ) + + if reduction == "batchmean": + return output_tensor.sum() / n_non_ignore, grads / n_non_ignore + elif reduction == "sum": + return output_tensor.sum(dim=0), grads + elif reduction == "mean": + return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V) + else: + return output_tensor, grads + + + def tvd_backward_triton(grad_output, grads): + # If this is the last layer, grad_output is 1.0; skip the multiplication then. 
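+ # (A scalar tensor 1.0 is exactly what autograd passes as grad_output when + # backward() is called directly on the scalar loss.)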
+ if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + return grads + + return grads * grad_output + + +class LigerTVDLossFunction(torch.autograd.Function): + """ + Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton. + """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, + p: torch.Tensor, + q: torch.Tensor, + shift_labels: Optional[torch.Tensor] = None, + reduction: REDUCTION_LITERAL = "batchmean", + ignore_index: int = -100, + ) -> torch.Tensor: + """A forward pass for the Total Variation Distance Loss. + + Args: + ctx: Torch autograd context + p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution. + q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution. + shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels. + reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean". + ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100. + + Returns: + torch.Tensor: The computed Total Variation Distance Loss. + """ + has_label = False + if shift_labels is not None: + assert shift_labels.shape == (p.shape[0],), ( + f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" + ) + shift_labels = shift_labels.contiguous() + has_label = True + + loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label) + ctx.save_for_backward(grads) + return loss + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + """A backward pass for the Total Variation Distance Loss. + + Args: + ctx: Torch autograd context + grad_output (torch.Tensor): The gradient of the loss with respect to the output. + + Returns: + tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs. + """ + (grads,) = ctx.saved_tensors + grads = tvd_backward_triton(grad_output, grads) + + return grads, None, None, None, None diff --git a/src/liger_kernel/ops/backends/_ascend/ub_manager.py b/src/liger_kernel/ops/backends/_ascend/ub_manager.py new file mode 100755 index 0000000000000000000000000000000000000000..0873ab619538c4bb33b924219f22fed7d10f1ec1 --- /dev/null +++ b/src/liger_kernel/ops/backends/_ascend/ub_manager.py @@ -0,0 +1,373 @@ +""" +Unified Buffer (UB) Manager for Ascend NPU. + +This module provides UB capacity detection and tiling strategy computation +for running Triton kernels on Ascend NPU. It automatically calculates +optimal block sizes based on UB capacity constraints to prevent UB overflow. +""" + +import os + +from typing import Optional +from typing import Tuple +from typing import Union + +import torch +import triton + +from liger_kernel.utils import is_npu_available + + +def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set: + """ + Normalize tiling dimension specification to a set of dimension indices. + + Args: + tiling_dim: Either an int (single dimension) or tuple of ints (multiple dimensions). + + Returns: + Set of dimension indices that can be tiled. 
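+ + Examples: 0 -> {0}; (0, 1) -> {0, 1}; any other type -> set().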
+ """ + if isinstance(tiling_dim, int): + return {tiling_dim} + elif isinstance(tiling_dim, tuple): + return set(tiling_dim) + else: + return set() + + +def _default_strategy( + ub_capacity_bits: int, + safety_margin: float, + dtype_size: int, + memory_multiplier: float, + shapes: Tuple[Tuple[int, ...], ...], + tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...], +) -> Tuple[int, ...]: + """ + Default tiling strategy: calculate maximum safe block size based on UB capacity. + + This is a unified strategy function that works for all kernels by abstracting + the memory calculation as: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits + + Args: + ub_capacity_bits: UB capacity in bits + safety_margin: Safety margin as a float (e.g., 0.80 for 80%) + dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32) + memory_multiplier: Memory multiplier for estimating peak memory usage + shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes. + - For ROPE: ((n_q_head, hd), (n_kv_head, hd)) + - For GEGLU: ((n_cols,),) + tiling_dims: Tuple specifying which dimensions can be tiled for each shape. + Each element can be: + - int: single dimension index (e.g., 0 for first dimension) + - tuple of ints: multiple dimensions that can be tiled together + - For ROPE: (0, 0) means first dimension of each shape can be tiled + - For GEGLU: (0,) means first dimension of the shape can be tiled + Length must match len(shapes). + + Returns: + Tuple of maximum safe block sizes, one for each shape. + Each element is a power of 2. + + Note: + For each shape, fixed dimensions (non-tiling) are multiplied together to get unit_param. + The final block size is computed in compute_default_tiling_strategy by taking + min(desired_block_size, max_safe_block_size) where desired_block_size = triton.next_power_of_2(original_dim). + """ + if not shapes or not tiling_dims: + return () + + # Calculate max_safe_block_size for each tiling dimension + max_safe_sizes = [] + + for shape, tiling_dim in zip(shapes, tiling_dims): + # Normalize tiling_dim to a set of dimension indices + tiling_dim_set = _normalize_tiling_dims(tiling_dim) + + # Validate tiling dimensions are within shape bounds + if not tiling_dim_set: + raise ValueError( + f"Invalid tiling_dim: {tiling_dim}. tiling_dim must be an int or a non-empty tuple of ints." + ) + if any(dim_idx < 0 or dim_idx >= len(shape) for dim_idx in tiling_dim_set): + raise ValueError( + f"Invalid tiling_dim: {tiling_dim} for shape {shape}. " + f"All dimension indices must be in range [0, {len(shape)})." 
+ ) + + # Calculate unit_param: product of fixed (non-tiling) dimensions + unit_param = 1.0 + for dim_idx, dim_size in enumerate(shape): + if dim_idx not in tiling_dim_set: + if dim_size <= 0: + # Invalid dimension size, use conservative default + unit_param = 1.0 + break + unit_param *= float(dim_size) + + # Ensure unit_param is at least 1.0 + if unit_param <= 0: + unit_param = 1.0 + + # Calculate maximum safe block size based on UB capacity + # Memory: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits + SAFE_UB_CAPACITY_BITS = int(ub_capacity_bits * safety_margin) + + # Solve: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 <= SAFE_UB_CAPACITY_BITS + # BLOCK_SIZE <= SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8) + max_block_size = int(SAFE_UB_CAPACITY_BITS // (memory_multiplier * unit_param * dtype_size * 8)) + max_block_size = max(1, max_block_size) + + # Find largest power of 2 <= max_block_size + # Use triton.next_power_of_2(max_block_size + 1) // 2 to get the largest power of 2 <= max_block_size + safe_block_size = triton.next_power_of_2(max_block_size + 1) // 2 + max_safe_sizes.append(safe_block_size) + + return tuple(max_safe_sizes) + + +class UBManager: + """ + Unified Buffer Manager for Ascend NPU. + + Provides UB capacity detection and management for Ascend NPU devices. + The UB capacity is used by tiling strategy functions to calculate optimal block sizes. + """ + + def __init__(self, ub_capacity_bits: Optional[int] = None): + """ + Initialize UB Manager. + + Args: + ub_capacity_bits: UB capacity in bits. If None, will be detected automatically. + """ + self._npu_model = self._detect_npu_model() + self._ub_capacity_bits = ub_capacity_bits or self._detect_ub_capacity() + + @property + def ub_capacity_bits(self) -> int: + """Get UB capacity in bits.""" + return self._ub_capacity_bits + + @property + def ub_capacity_bytes(self) -> int: + """Get UB capacity in bytes.""" + return self._ub_capacity_bits // 8 + + @property + def npu_model(self) -> str: + """Get detected NPU model name.""" + return self._npu_model + + def _detect_npu_model(self) -> str: + """Detect NPU model from device properties.""" + if not is_npu_available(): + return "unknown" + + try: + dev_props = torch.npu.get_device_properties(0) + # Try to get model name from device properties + return dev_props.name + except Exception: + pass + + return "default" + + def _detect_ub_capacity(self) -> int: + """ + Detect UB capacity from environment variable or get_soc_spec. + + Returns: + UB capacity in bits. + + Raises: + RuntimeError: If UB capacity cannot be detected and no environment variable is set. 
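+ + For example, the 192 KB UB cited for Atlas 800I A2 in the swiglu helper + corresponds to 192 * 1024 * 8 = 1,572,864 bits.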
+ """ + # Check environment variable first (in bits) + env_capacity = os.getenv("ASCEND_UB_CAPACITY_BITS") + if env_capacity is not None: + try: + capacity_bits = int(env_capacity) + if capacity_bits > 0: + return capacity_bits + except ValueError: + pass + + # Try to get from get_soc_spec (returns bytes, convert to bits) + if is_npu_available(): + try: + from tbe.common.platform import get_soc_spec + from tbe.common.platform import set_current_compile_soc_info + + # Set current SOC info for get_soc_spec to work correctly + device = getattr(torch, "npu") + soc_info = device.get_device_name(device.current_device()) + set_current_compile_soc_info(soc_info) + + # Query UB size (get_soc_spec returns size in bytes) + ub_size_bytes = get_soc_spec("UB_SIZE") + + if ub_size_bytes is None or ub_size_bytes <= 0: + raise ValueError(f"Invalid UB_SIZE from get_soc_spec: {ub_size_bytes}") + + # Convert bytes to bits + ub_capacity_bits = ub_size_bytes * 8 + return ub_capacity_bits + + except ImportError: + raise RuntimeError( + "Cannot import tbe.common.platform.get_soc_spec. " + "Please ensure CANN environment variables are sourced " + "(e.g., source /usr/local/Ascend/ascend-toolkit/set_env.sh)" + ) + except Exception as e: + raise RuntimeError( + f"Failed to detect UB capacity from get_soc_spec: {e}. " + "Please set ASCEND_UB_CAPACITY_BITS environment variable as fallback." + ) + + # If NPU is not available, raise error + raise RuntimeError( + "NPU is not available and UB capacity cannot be detected. " + "Please set ASCEND_UB_CAPACITY_BITS environment variable." + ) + + +# Global singleton instance +_ub_manager: Optional[UBManager] = None + + +def get_ub_manager() -> UBManager: + """Get global UB manager instance.""" + global _ub_manager + if _ub_manager is None: + _ub_manager = UBManager() + return _ub_manager + + +def compute_default_tiling_strategy( + safety_margin: float = 0.80, + dtype_size: Optional[int] = None, + memory_multiplier: Optional[float] = None, + shapes: Optional[Tuple[Tuple[int, ...], ...]] = None, + tiling_dims: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None, +) -> Optional[Tuple[Tuple[int, ...], ...]]: + """ + Compute tiling strategy using the default strategy function. + + This function directly calls the default strategy and computes the final + tiling result. All kernels use the same unified strategy function, so + there's no need for kernel_name-based lookup. + + Args: + safety_margin: Safety margin as a float (e.g., 0.80 for 80%). Default is 0.80. + dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32). + Must be provided. If None or <= 0, defaults to 4 (float32). + memory_multiplier: Memory multiplier for estimating peak memory usage. + - For GEGLU: typically 10.0 for backward, 4.0 for forward + - For ROPE: typically 3.0 + If None, defaults to 10.0 (conservative estimate). + shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes. + - For ROPE: ((n_q_head, hd), (n_kv_head, hd)) + - For GEGLU: ((n_cols,),) + Can pass original shapes (will handle padding internally) or padded shapes. + tiling_dims: Tuple specifying which dimensions can be tiled for each shape. + Each element can be: + - int: single dimension index (e.g., 0 for first dimension) + - tuple of ints: multiple dimensions that can be tiled together + - For ROPE: (0, 0) means first dimension of each shape can be tiled + - For GEGLU: (0,) means first dimension of the shape can be tiled + Length must match len(shapes). Cannot be empty. 
+ + Returns: + Tuple of tiled shapes with same structure as input shapes. + Tiling dimensions are replaced with computed block sizes (power of 2), + while non-tiling dimensions are padded to next power of 2. + - For ROPE: ((block_size_q, pad_hd), (block_size_kv, pad_hd)) + - For GEGLU: ((block_size,),) + Returns None if shapes or tiling_dims is None or empty. + + Examples: + >>> # ROPE forward + >>> strategy = compute_default_tiling_strategy( + ... safety_margin=0.90, + ... dtype_size=4, + ... memory_multiplier=3.0, + ... shapes=((32, 128), (32, 128)), + ... tiling_dims=(0, 0) + ... ) + >>> # Returns: ((block_size_q, 128), (block_size_kv, 128)) + >>> # GEGLU forward + >>> strategy = compute_default_tiling_strategy( + ... safety_margin=0.80, + ... dtype_size=2, + ... memory_multiplier=7.0, + ... shapes=((4096,),), + ... tiling_dims=(0,) + ... ) + >>> # Returns: ((block_size,),) + """ + ub_manager = get_ub_manager() + + if shapes is None or not shapes or tiling_dims is None or not tiling_dims: + return None + + if len(shapes) != len(tiling_dims): + return None + + if dtype_size is None or dtype_size <= 0: + dtype_size = 4 # Default to float32 + + if memory_multiplier is None or memory_multiplier <= 0: + memory_multiplier = 10.0 # Default conservative estimate + + # Call strategy to get max_safe_block_size for each shape + max_supported = _default_strategy( + ub_manager.ub_capacity_bits, + safety_margin, + dtype_size, + memory_multiplier, + shapes, + tiling_dims, + ) + + if not max_supported or len(max_supported) != len(shapes): + return None + + # Build result: same structure as shapes, with tiling dims replaced by computed block sizes + result = [] + for shape, tiling_dim, max_safe in zip(shapes, tiling_dims, max_supported): + result_shape = list(shape) + + # Normalize tiling_dim to a set of dimension indices + tiling_dim_set = _normalize_tiling_dims(tiling_dim) + + # Validate tiling dimensions are within shape bounds + if not tiling_dim_set: + raise ValueError( + f"Invalid tiling_dim: {tiling_dim}. tiling_dim must be an int or a non-empty tuple of ints." + ) + if any(dim_idx < 0 or dim_idx >= len(result_shape) for dim_idx in tiling_dim_set): + raise ValueError( + f"Invalid tiling_dim: {tiling_dim} for shape {shape}. " + f"All dimension indices must be in range [0, {len(result_shape)})." + ) + + # Replace tiling dimensions with computed block sizes + # For each tiling dimension, compute: min(desired, max_safe) + for dim_idx in tiling_dim_set: + original_dim = result_shape[dim_idx] + desired = triton.next_power_of_2(original_dim) + final_val = min(desired, max_safe) + final_val = max(1, final_val) # Ensure at least 1 + result_shape[dim_idx] = final_val + + # Pad non-tiling dimensions to next power of 2 + for dim_idx, dim_size in enumerate(result_shape): + if dim_idx not in tiling_dim_set: + result_shape[dim_idx] = triton.next_power_of_2(dim_size) + + result.append(tuple(result_shape)) + + return tuple(result) diff --git a/src/liger_kernel/ops/backends/registry.py b/src/liger_kernel/ops/backends/registry.py new file mode 100755 index 0000000000000000000000000000000000000000..5fe3613c82304d33e20d68b536823edc2c9d152e --- /dev/null +++ b/src/liger_kernel/ops/backends/registry.py @@ -0,0 +1,61 @@ +""" +Vendor registry for Liger-Kernel multi-backend support. + +This module defines VendorInfo and the registry for vendor registration. +Each vendor registers itself by calling register_vendor() in its __init__.py. 
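+
+A minimal registration sketch (illustrative; "ascend"/"npu" are just one possible vendor/device pair):
+
+    from liger_kernel.ops.backends.registry import VendorInfo, register_vendor
+
+    register_vendor(VendorInfo(vendor="ascend", device="npu"))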
+""" + +from dataclasses import dataclass +from typing import Optional + +# Dynamically get backends package path to avoid hardcoding +_BACKENDS_PACKAGE = __name__.rsplit(".", 1)[0] # "liger_kernel.ops.backends" + + +@dataclass +class VendorInfo: + """ + Information about a chip vendor and its supported device. + + Attributes: + vendor: Vendor name (e.g., "ascend", "intel", "nvidia") + device: Device type this vendor supports (e.g., "npu", "xpu") + """ + + vendor: str + device: str + + @property + def module_path(self) -> str: + """Auto-generated module path based on vendor name.""" + return f"{_BACKENDS_PACKAGE}._{self.vendor}.ops" + + +# Registry mapping device types to their vendor info +# Vendors register themselves via register_vendor() +VENDOR_REGISTRY: dict[str, VendorInfo] = {} + + +def register_vendor(vendor_info: VendorInfo) -> None: + """ + Register a vendor's info in the global registry. + + This should be called in each vendor's __init__.py to register itself. + + Args: + vendor_info: VendorInfo instance to register + """ + VENDOR_REGISTRY[vendor_info.device] = vendor_info + + +def get_vendor_for_device(device: str) -> Optional[VendorInfo]: + """ + Get the VendorInfo for a given device type. + + Args: + device: Device type (e.g., "npu", "xpu") + + Returns: + VendorInfo if found, None otherwise + """ + return VENDOR_REGISTRY.get(device) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py new file mode 100755 index 0000000000000000000000000000000000000000..4793f75c9d14e25e1503908f8deb3879c812f010 --- /dev/null +++ b/src/liger_kernel/ops/cross_entropy.py @@ -0,0 +1,558 @@ +import operator + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import compare_version +from liger_kernel.ops.utils import element_mul_kernel +from liger_kernel.ops.utils import is_hip +from liger_kernel.utils import infer_device +from liger_kernel.utils import is_npu_available + +if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available(): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import tanh + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import tanh +else: + from triton.language.math import tanh + + +@triton.jit +def liger_cross_entropy_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + weight_ptr, + loss_ptr, + z_loss_ptr, + loss_stride, + token_accuracy_ptr, + token_accuracy_stride, + predicted_tokens_ptr, + predicted_tokens_stride, + n_cols, + n_non_ignore, + sum_non_ignore_weight, + weight_sum, + ignore_index, + lse_square_scale: tl.constexpr, + label_smoothing: tl.constexpr, + reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time + softcap, + RETURN_Z_LOSS: tl.constexpr, + RETURN_TOKEN_ACCURACY: tl.constexpr, + RETURN_PREDICTED_TOKENS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_SOFTCAPPING: tl.constexpr, + HAS_GRADIENTS: tl.constexpr, +): + """ + This kernel computes both cross entropy loss and the gradient of the input. + We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. 
+ weight_ptr: Pointer to weight tensor. + loss_ptr: Pointer to tensor to store the loss. + z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. + loss_stride (int): The stride of the loss tensor. + token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0. + token_accuracy_stride (int): The stride of the token accuracy tensor. + n_cols (int): The number of columns in the input tensor. + n_non_ignore (float): The number of non-ignored elements in the batch. + sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch. + weight_sum (float): The sum of weight tensor. + ignore_index (int): The index to ignore in the target. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. + reduction (str): The string for the reduction to apply + softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). + RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1. + RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1. + BLOCK_SIZE (int): The block size for Triton operations. + HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes. + HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. + HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass. + """ + + # https://github.com/triton-lang/triton/issues/1058 + # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64 + program_id = tl.program_id(0).to(tl.int64) + + # 1. Load Y_ptr first because if the target is ignore_index, we can return right away + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + # 2. locate the start index + X_ptr += program_id * X_stride + + if y == ignore_index: + # set all X_ptr as 0 + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) + # For ignored tokens, set token accuracy to 0 + if RETURN_TOKEN_ACCURACY: + token_accuracy_ptr += program_id * token_accuracy_stride + tl.store(token_accuracy_ptr, 0.0) + if RETURN_PREDICTED_TOKENS: + predicted_tokens_ptr += program_id * predicted_tokens_stride + tl.store(predicted_tokens_ptr, -1) + return + + loss_ptr += program_id * loss_stride + if RETURN_Z_LOSS: + z_loss_ptr += program_id * loss_stride + if RETURN_TOKEN_ACCURACY: + token_accuracy_ptr += program_id * token_accuracy_stride + if RETURN_PREDICTED_TOKENS: + predicted_tokens_ptr += program_id * predicted_tokens_stride + + if HAS_WEIGHT: + weight_y = tl.load(weight_ptr + y).cast(tl.float32) + + # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) + # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 + + # 3. [Online softmax] first pass: find max + sum + m = float("-inf") # m is the max value. use the notation from the paper + d = 0.0 # d is the sum. 
use the notation from the paper + argmax_idx = 0 # Track the index of the maximum value for token accuracy / predicted tokens computation + ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation + if HAS_SOFTCAPPING: + ori_X_y = softcap * tanh(ori_X_y / softcap) + + # Label smoothing is a general case of normal cross entropy + # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310 + scaled_x_sum = 0.0 + eps = label_smoothing / n_cols + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + X_block = softcap * tanh(X_block / softcap) + block_max = tl.max(X_block) + + # Track argmax for accuracy / predicted tokens computation + if RETURN_TOKEN_ACCURACY or RETURN_PREDICTED_TOKENS: + # Find the index of the maximum value in this block + is_max_mask = X_block == block_max + # Mask out invalid indices with a value larger than n_cols + masked_offsets = tl.where(is_max_mask, X_offsets, n_cols) + # Get the first (smallest) index where max occurs + current_block_argmax_idx = tl.min(masked_offsets) + + is_new_max = block_max > m + argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx) + + if label_smoothing > 0: + # scale X beforehand to avoid overflow + if HAS_WEIGHT: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0)) + else: + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) + m = m_new + + # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X))))) + # = log (e^(max(X)) * sum(e ^ (X_i - max(X)))) + # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d + lse = m + tl.log(d) + + # 4. 
[Online Softmax] Second pass: compute gradients + # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) + # dx_y = (softmax(x_y) - 1) / N + # dx_i = softmax(x_i) / N, i != y + # For label smoothing: + # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N + # With Z loss: + # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y + # dx_y = dx_i - (1 - label_smoothing) / N + # For 'sum' reduction, no normalization is applied: + # dx_y = softmax(x_y) - 1 + # dx_i = softmax(x_i), for i ≠ y + if HAS_GRADIENTS: + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + intermediate = tanh(X_block / softcap) + X_block = softcap * intermediate + + if not HAS_WEIGHT: + # softmax(x_i) + X_block = tl.exp(X_block - m) / d + # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) + X_block += 2 * lse_square_scale * lse * X_block + # smoothing term + X_block += -eps + # special handle dx_y + X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) + # reduction scale + if reduction == "mean": + X_block = X_block / n_non_ignore + else: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + softmax_X = tl.exp(X_block - m) / d + # derivative of original_loss + dloss_ori = (1 - label_smoothing) * softmax_X + # specially handle dx_y + dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing)) + dloss_ori = dloss_ori * weight_y + # derivative of smooth_loss + dloss_smooth = eps * (-weight_block + softmax_X * weight_sum) + # derivative of z-loss + dz_loss = 2 * lse_square_scale * lse * softmax_X + # reduction scale + if reduction == "mean": + dloss_ori = dloss_ori / sum_non_ignore_weight + dloss_smooth = dloss_smooth / sum_non_ignore_weight + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. + dz_loss = dz_loss / n_non_ignore + # derivative of total_loss + X_block = dloss_ori + dloss_smooth + dz_loss + + # chain rule softcapping + # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) + if HAS_SOFTCAPPING: + X_block = X_block * (1 - intermediate * intermediate) + + tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) + + # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in + # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34 + tl.debug_barrier() + + # 5. 
Calculate the loss
+
+    # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
+    # = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
+    # = X_y - m - log d = X_y - lse
+    # sum(e ^ (X - max(X))) must be >= 1 because the max term is e ^ 0 = 1
+    # So we can safely calculate log (softmax(X_y)) without overflow
+    loss = lse - ori_X_y
+    if HAS_WEIGHT:
+        loss = weight_y * loss
+
+    # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
+    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
+    # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
+    # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
+    # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
+    # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
+    # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
+    # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
+    if label_smoothing > 0:
+        if HAS_WEIGHT:
+            smooth_loss = scaled_x_sum + eps * lse * weight_sum
+        else:
+            smooth_loss = scaled_x_sum + label_smoothing * lse
+        loss = loss * (1 - label_smoothing) + smooth_loss
+
+    # An auxiliary loss, z_loss
+    # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
+    z_loss = lse_square_scale * lse * lse
+    # Normalize the loss by the number of non-ignored elements if reduction is "mean"
+    if reduction == "mean":
+        if HAS_WEIGHT:
+            loss = loss / sum_non_ignore_weight
+        else:
+            loss = loss / n_non_ignore
+        # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+        z_loss = z_loss / n_non_ignore
+    loss += z_loss
+
+    tl.store(loss_ptr, loss)
+    if RETURN_Z_LOSS:
+        tl.store(z_loss_ptr, z_loss)
+    if RETURN_TOKEN_ACCURACY:
+        # Store 1.0 if prediction is correct, 0.0 otherwise
+        is_correct = 1.0 if argmax_idx == y else 0.0
+        tl.store(token_accuracy_ptr, is_correct)
+    if RETURN_PREDICTED_TOKENS:
+        tl.store(predicted_tokens_ptr, argmax_idx)
+
+
+# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
+# However, setting the limit to 65536 as in the LayerNorm tutorial is faster because of less register spilling
+# The optimal maximum block size depends on your hardware, your kernel, and your dtype;
+# these are the best sizes we found by manually tuning on XPU and NPU.
+if infer_device() == "xpu":
+    MAX_FUSED_SIZE = 4096
+elif infer_device() == "npu":
+    MAX_FUSED_SIZE = 2048
+else:
+    MAX_FUSED_SIZE = 65536 // 2
+
+
+def cross_entropy_forward(
+    _input,
+    target,
+    weight,
+    ignore_index,
+    lse_square_scale,
+    label_smoothing,
+    reduction,
+    softcap,
+    return_z_loss,
+    return_token_accuracy=False,
+    return_predicted_tokens=False,
+):
+    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
+    assert isinstance(return_predicted_tokens, bool), (
+        f"return_predicted_tokens must be True or False. 
Got: {return_predicted_tokens}" + ) + + BT, V = _input.shape + n_rows = BT + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + # unreduced loss + loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) + z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None + token_accuracy_1d = ( + torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None + ) + predicted_tokens_1d = ( + torch.full((n_rows,), -1, dtype=torch.int64, device=_input.device) if return_predicted_tokens else None + ) + + target_mask = target != ignore_index + n_non_ignore = target_mask.sum().item() + assert (target * target_mask).max() < _input.shape[-1], ( + f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}" + ) + assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0" + sum_non_ignore_weight = n_non_ignore + weight_sum = 0.0 + if weight is not None: + assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}" + assert torch.is_floating_point(weight), ( + f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}" + ) + sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item() + weight_sum = weight.sum().item() + # ensure weight is contiguous + if weight.stride(-1) != 1: + weight = weight.contiguous() + + # ensure _input and target are contiguous in the last dimension + if _input.stride(-1) != 1: + _input = _input.contiguous() + if target.stride(-1) != 1: + target = target.contiguous() + + # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory + liger_cross_entropy_kernel[(n_rows,)]( + X_ptr=_input, + X_stride=_input.stride(-2), + Y_ptr=target, + Y_stride=target.stride(-1), # always 1 + weight_ptr=weight, # dummy if None + loss_ptr=loss_1d, + z_loss_ptr=z_loss_1d, + loss_stride=loss_1d.stride(-1), # always 1 + token_accuracy_ptr=token_accuracy_1d, + token_accuracy_stride=token_accuracy_1d.stride(-1) + if return_token_accuracy + else 0, # always 1 if accuracy is enabled + predicted_tokens_ptr=predicted_tokens_1d, + predicted_tokens_stride=predicted_tokens_1d.stride(-1) + if return_predicted_tokens + else 0, # always 1 if predicted tokens is enabled + n_cols=V, + n_non_ignore=n_non_ignore, + sum_non_ignore_weight=sum_non_ignore_weight, + ignore_index=ignore_index, + weight_sum=weight_sum, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, + RETURN_Z_LOSS=return_z_loss, + RETURN_TOKEN_ACCURACY=return_token_accuracy, + RETURN_PREDICTED_TOKENS=return_predicted_tokens, + BLOCK_SIZE=BLOCK_SIZE, + HAS_WEIGHT=True if weight is not None else False, + HAS_SOFTCAPPING=True if softcap is not None else False, + HAS_GRADIENTS=_input.requires_grad, + # TODO: 32 seems to give the best performance + # Performance is quite sensitive to num_warps + num_warps=32 if not is_hip() else 16, + ) + + if reduction == "none": + loss = loss_1d + z_loss = z_loss_1d if return_z_loss else None + token_accuracy = token_accuracy_1d if return_token_accuracy else None + else: + loss = torch.sum(loss_1d) + z_loss = torch.sum(z_loss_1d) if return_z_loss else None + # For accuracy, we compute the mean across all non-ignored tokens + token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None + + predicted_tokens = predicted_tokens_1d if 
return_predicted_tokens else None
+
+    return loss, z_loss, token_accuracy, predicted_tokens, _input
+
+
+def cross_entropy_backward(_input, grad_output):
+    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
+    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+        pass
+    # If reduction is 'none'
+    elif grad_output.ndim > 0:
+        _input = _input * grad_output.unsqueeze(dim=1)
+    # If reduction is ['mean', 'sum'], grad_output is just a scalar
+    # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
+    # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
+    else:
+        BT, V = _input.shape
+        n_rows = BT
+        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+        element_mul_kernel[(n_rows,)](
+            _input,
+            _input.stride(-2),
+            grad_output,
+            V,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=32 if not is_hip() else 16,
+        )
+
+    return _input
+
+
+class LigerCrossEntropyFunction(torch.autograd.Function):
+    """
+    This class implements a custom autograd function for the Liger Cross Entropy loss.
+    It overrides the forward and backward methods of the torch.autograd.Function class.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input: torch.Tensor,
+        target: torch.Tensor,
+        weight: Optional[torch.FloatTensor],
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+        return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
+        return_predicted_tokens: bool = False,
+    ):
+        """
+        The forward pass of the Liger Cross Entropy loss.
+
+        Parameters:
+        ctx : The context object.
+        _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
+        target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+        weight (Tensor, optional): A manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype.
+        ignore_index (int): The index to ignore in the target.
+        lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
+        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+        reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
+        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy, predicted_tokens) instead of (loss, None, None, None). Default: `False`
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
+        return_predicted_tokens (bool): When `return_predicted_tokens` is `True`, returns per-token predicted class indices (argmax) without materializing logits. Default: `False`
+
+        Returns:
+        tuple: A tuple with the computed losses, accuracy, and predicted tokens: (loss, z_loss, token_accuracy, predicted_tokens). z_loss, token_accuracy, and predicted_tokens are None if not requested. 
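+
+        A minimal usage sketch (shapes, values, and device are illustrative):
+        >>> logits = torch.randn(8, 32000, device="cuda", requires_grad=True)  # (BT, V)
+        >>> target = torch.randint(0, 32000, (8,), device="cuda")
+        >>> loss, z_loss, acc, pred = LigerCrossEntropyFunction.apply(logits, target, None)
+        >>> loss.backward()  # z_loss, acc, and pred are None unless the corresponding flags are set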
+ """ + input_requires_grad = _input.requires_grad + + loss, z_loss, token_accuracy, predicted_tokens, _input = cross_entropy_forward( + _input, + target, + weight, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, + return_token_accuracy, + return_predicted_tokens, + ) + # TODO: investigation + # If we don't detach the _input tensor, the memory will double + # Not sure why but seems that there will be a time both grad and value exist but in different location + if input_requires_grad: + ctx.save_for_backward(_input.detach()) + ctx.return_z_loss = return_z_loss + ctx.return_token_accuracy = return_token_accuracy + ctx.return_predicted_tokens = return_predicted_tokens + + return loss, z_loss, token_accuracy, predicted_tokens + + @staticmethod + def backward(ctx, grad_output, grad_output2, grad_output3, grad_output4): + """ + The backward pass of the Liger Cross Entropy loss. + + Parameters: + ctx : The context object with saved tensors. + grad_output (tensor): The tensor containing the gradient of the loss with respect to the output. + grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging). + grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics). + grad_output4 (tensor): No use. Gradient for predicted_tokens (not used as predicted_tokens is only for metrics). + Returns: + tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. + """ + if ctx.return_z_loss: + del grad_output2 # z_loss is only for logging + if ctx.return_token_accuracy: + del grad_output3 # token_accuracy is only for metrics + if ctx.return_predicted_tokens: + del grad_output4 # predicted_tokens is only for metrics + + (_input,) = ctx.saved_tensors + _input = cross_entropy_backward(_input, grad_output) + return ( + _input, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) diff --git a/src/liger_kernel/ops/dyt.py b/src/liger_kernel/ops/dyt.py new file mode 100755 index 0000000000000000000000000000000000000000..432c0ee275681150c8915906ab7c0d334405ae7c --- /dev/null +++ b/src/liger_kernel/ops/dyt.py @@ -0,0 +1,160 @@ +import operator + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import compare_version +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count +from liger_kernel.ops.utils import infer_device +from liger_kernel.utils import is_npu_available + +if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available(): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import tanh + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import tanh +else: + from triton.language.math import tanh + + +# @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw) +# for bn in [1024, 2048, 4096] +# for ns in [1,2,4] +# for nw in [4, 8, 16, 32] +# ], +# key=['N']) +@triton.jit +def _dyt_fwd_kernel(X, Y, Alpha, Gamma, Beta, HAVE_BETA: tl.constexpr, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024): + col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N) + mask = col < N + row_id = tl.cast(tl.program_id(1), tl.int64) + + X += row_id * N + Y += row_id * N + alpha = tl.load(Alpha).to(tl.float32) + + gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32) + + x = tl.load(X + col, 
mask=mask, other=0.0).to(tl.float32) + + tanh_x = tanh(alpha * x) + y = tanh_x * gamma + if HAVE_BETA: + beta = tl.load(Beta + col, mask=mask, other=0.0).to(tl.float32) + y += beta + tl.store(Y + col, y, mask=mask) + + +# @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw) +# for bn in [1024, 2048, 4096] +# for ns in [1,2,4] +# for nw in [4, 8, 16] +# ], +# key=['N']) +@triton.jit +def _dyt_bwd_kernel( + DY, DX, DA, DG, DB, X, Alpha, Gamma, HAVE_BETA: tl.constexpr, M, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024 +): + col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N) + mask = col < N + start_row_id = tl.cast(tl.program_id(1), tl.int64) + + alpha = tl.load(Alpha).to(tl.float32) + da = 0.0 + gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32) + dg = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAVE_BETA: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + for row_id in range(start_row_id, M, tl.num_programs(1)): + x = tl.load(X + row_id * N + col, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + row_id * N + col, mask=mask, other=0.0).to(tl.float32) + tanh_x = tanh(alpha * x) + if HAVE_BETA: + db += dy + dg += dy * tanh_x + tmp = (1 - tanh_x * tanh_x) * dy * gamma + da += tl.sum(x * tmp, 0) + dx = alpha * tmp + tl.store(DX + row_id * N + col, dx, mask=mask) + + tl.store(DG + start_row_id * N + col, dg, mask=mask) + if HAVE_BETA: + tl.store(DB + start_row_id * N + col, db, mask=mask) + tl.store(DA + start_row_id * tl.cdiv(N, 512) + tl.program_id(0), da) + + +def liger_dyt_fwd(x, alpha, gamma, beta): + assert x.is_contiguous() + HAVE_BETA = True if beta is not None else False + input_shape = x.shape + x = x.view(-1, input_shape[-1]) + M, N = x.shape + + y = torch.empty_like(x) + + if N >= 4096: + kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 2048), "num_warps": 4, "num_stages": 1} + else: + kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 4, "num_stages": 1} + + grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), M) + _dyt_fwd_kernel[(grid)]( + x, + y, + alpha, + gamma, + beta, + HAVE_BETA, + N, + **kwargs, + ) + return y.view(input_shape) + + +def liger_dyt_bwd(dy, x, alpha, gamma, beta): + assert dy.is_contiguous() + input_shape = x.shape + x = x.view(-1, input_shape[-1]) + M, N = x.shape + HAVE_BETA = True if beta is not None else False + + device = infer_device() + if device == "cuda": + NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count + elif device == "xpu": + NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count + elif device == "npu": + NUM_SMS = get_npu_core_count() + da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device) + dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) + db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None + dx = torch.empty_like(dy) + + kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 8, "num_stages": 2} + grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), NUM_SMS) + _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, **kwargs) + if HAVE_BETA: + db = db.sum(0).to(x.dtype) + dg = dg.sum(0).to(gamma.dtype) + da = da.sum().to(x.dtype).unsqueeze(0) + return dx.view(input_shape), da, dg, db + + +class LigerDyTFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, x, alpha, gamma, beta): + y = liger_dyt_fwd(x, alpha, gamma, beta) + ctx.save_for_backward(x, alpha, 
gamma, beta) + return y + + @staticmethod + @ensure_contiguous + def backward(ctx, dy): + x, alpha, gamma, beta = ctx.saved_tensors + dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta) + return dx, dalpha, dgamma, dbeta diff --git a/src/liger_kernel/ops/experimental/embedding.py b/src/liger_kernel/ops/experimental/embedding.py new file mode 100755 index 0000000000000000000000000000000000000000..159b9a66d64158332c37e763ca9763ac9ede1932 --- /dev/null +++ b/src/liger_kernel/ops/experimental/embedding.py @@ -0,0 +1,141 @@ +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import ensure_contiguous + + +@triton.jit +def embedding_forward_kernel( + embeddings_ptr, + indices_ptr, + output_ptr, + n_elements, + embedding_dim: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M) + mask_m = offsets_m < n_elements + indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0) + offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N) + mask_n = offsets_n < embedding_dim + + embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :] + embeddings = tl.load( + embeddings_ptr + embedding_offsets, + mask=mask_m[:, None] & mask_n[None, :], + other=0.0, + ) + + output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :] + tl.store(output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :]) + + +@triton.jit +def embedding_backward_kernel( + grad_output_ptr, + grad_weight_ptr, + indices_ptr, + n_elements, + embedding_dim: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M) + mask_m = offsets_m < n_elements + indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0) + offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N) + mask_n = offsets_n < embedding_dim + + grad_output = tl.load( + grad_output_ptr + offsets_m[:, None] * embedding_dim + offsets_n[None, :], + mask=mask_m[:, None] & mask_n[None, :], + other=0.0, + ) + + grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :] + + tl.atomic_add( + grad_weight_ptr + grad_weight_offsets, + grad_output, + mask=mask_m[:, None] & mask_n[None, :], + ) + + +class LigerEmbeddingFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor): + ori_shape = indices.shape + indices = indices.view(-1) + output = torch.empty( + indices.shape[0], + embeddings.shape[1], + device=indices.device, + dtype=embeddings.dtype, + ) + + n_elements = indices.numel() + embedding_dim = embeddings.shape[1] + + BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim)) + BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim)) + grid = ( + triton.cdiv(n_elements, BLOCK_SIZE_M), + triton.cdiv(embedding_dim, BLOCK_SIZE_N), + ) + + embedding_forward_kernel[grid]( + embeddings, + indices, + output, + n_elements, + embedding_dim=embedding_dim, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + + ctx.save_for_backward(indices, embeddings) + + return output.view(*ori_shape, -1) + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor): + indices, embedding_table 
= ctx.saved_tensors + grad_output = grad_output.contiguous().view(-1, embedding_table.shape[1]) + + grad_weight = torch.zeros_like(embedding_table) + + n_elements = indices.numel() + embedding_dim = embedding_table.shape[1] + + BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim)) + BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim)) + grid = ( + triton.cdiv(n_elements, BLOCK_SIZE_M), + triton.cdiv(embedding_dim, BLOCK_SIZE_N), + ) + + embedding_backward_kernel[grid]( + grad_output, + grad_weight, + indices, + n_elements, + embedding_dim=embedding_dim, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + + return grad_weight, None diff --git a/src/liger_kernel/ops/experimental/mm_int8int2.py b/src/liger_kernel/ops/experimental/mm_int8int2.py new file mode 100755 index 0000000000000000000000000000000000000000..326d536326698174b944b4915b4443a670980ca1 --- /dev/null +++ b/src/liger_kernel/ops/experimental/mm_int8int2.py @@ -0,0 +1,349 @@ +import torch +import triton +import triton.language as tl + + +def unpack_weights(packed: torch.Tensor, bits: int = 2) -> torch.Tensor: + values_per_item = 8 // bits + packed_shape = packed.shape + + if len(packed_shape) == 1: + original_row_dim = packed_shape[0] * values_per_item + unpacked_shape = (original_row_dim,) + else: + original_row_dim = packed_shape[0] * values_per_item + unpacked_shape = (original_row_dim, *packed_shape[1:]) + + unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8) + + for i in range(values_per_item): + start = i * packed_shape[0] + end = start + packed_shape[0] + mask = 3 << (2 * i) + unpacked[start:end] = (packed & mask) >> (2 * i) + + unpacked = unpacked.to(torch.int32) - 1 + return unpacked + + +def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor: + intweights += 1 + original_shape = intweights.shape + values_per_item = 8 // bits + row_dim = (original_shape[0] + values_per_item - 1) // values_per_item + + if len(original_shape) == 1: + packed_tensor_shape = (row_dim,) + else: + packed_tensor_shape = (row_dim, *original_shape[1:]) + + packed = torch.zeros(packed_tensor_shape, device=intweights.device, dtype=torch.uint8) + unpacked = intweights.to(torch.uint8) + + def lshift(t: torch.Tensor, bits: int): + return t << bits + + it = min(values_per_item, (original_shape[0] // row_dim) + 1) + for i in range(it): + start = i * row_dim + end = min(start + row_dim, original_shape[0]) + packed[: (end - start)] |= lshift(unpacked[start:end], bits * i) + + return packed + + +def get_autotune_config(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 
256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + }, + num_stages=4, + num_warps=4, + ), + ] + + +@triton.autotune( + configs=get_autotune_config(), + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K: tl.constexpr, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # We want K / 4 to be divisible by BLOCK_SIZE_K so that the multiplication can be aligned + tl.static_assert( + K % (4 * BLOCK_SIZE_K) == 0, + "K / 4 must be divisible by BLOCK_SIZE_K => K divisible by 4*BLOCK_SIZE_K", + ) + # determine the block id in the 1D grid, pid <=> blockId in cuda + pid = tl.program_id(axis=0) + # number of blocks we would need in the M dimension + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + # number of blocks we would need in the N dimension + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + # blocks are grouped along the M dimension. num_pid_in_group computes how many blocks are grouped together, + # and group_id calculates the group to which the current block (pid) belongs. 
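+    # Illustrative example (values assumed): with GROUP_SIZE_M = 2 and num_pid_n = 3,
+    # num_pid_in_group = 6, so pids 0..5 form group 0 and map (via the computation below) to
+    # (pid_m, pid_n) = (0,0), (1,0), (0,1), (1,1), (0,2), (1,2):
+    # two rows of output blocks are swept column by column, which improves L2 reuse of A's rows.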
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+
+    # pid of the first block in the group that the current block belongs to
+    first_pid_m = group_id * GROUP_SIZE_M
+
+    # pid_m : pid of the block along the M dimension of the output matrix, and pid_n : pid of the block along the N dimension of the output matrix
+    # remember that the grid of blocks is 1D, but we calculate pid_m and pid_n to locate the block's place in the output matrix
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # offs_am represents the indices of elements within the block for matrix A with respect to the M dimension
+    # offs_bn represents the indices of elements within the block for matrix B with respect to the N dimension
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+
+    """
+    This part of the code generates pointers to the specific blocks of matrices A and B that the current thread block will process.
+
+    As described in the PyTorch documentation, a stride refers to the step size needed to move from one element to the next along a given dimension:
+
+    For matrix A: stride_am = A.stride(0) = K (stride along the rows), and stride_ak = A.stride(1) = 1 (stride along the columns).
+    For matrix B: stride_bk = B.stride(0) = N (stride along the rows), and stride_bn = B.stride(1) = 1 (stride along the columns).
+    Now, let's break down the pointer generation:
+
+    offs_am[:, None] creates a column of shape [BLOCK_SIZE_M, 1], which represents the row indices of matrix A that this block is processing. It is multiplied by K (the number of columns in matrix A) since A is stored in row-major order. So, the element at position (i, j) in A is located at index i*K + j in memory.
+    offs_k[None, :] creates a row vector of shape [1, BLOCK_SIZE_K] representing the column indices of the block, i.e., a range from 0 to BLOCK_SIZE_K. This is used to compute the positions of the columns within the block.
+    When combined, the result has the shape [BLOCK_SIZE_M, BLOCK_SIZE_K], where each entry (i, j) points to the element in matrix A at position (i, j) for the current block.
+
+    The same logic is applied to matrix B, but the resulting shape is [BLOCK_SIZE_K, BLOCK_SIZE_N], representing the block of matrix B that the thread block will work on.
+    """
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    # An accumulator matrix is initialized with zeros. It stores the intermediate results of the block matrix multiplication.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
+    """
+    We split the loop into two layers. The outer loop runs 4 times, and each iteration focuses on a specific portion of matrix A.
+
+    For example, when i = 0, we're only concerned with the blocks of matrix A that cover the range from 0 to K // (4 * BLOCK_SIZE_K).
+    Since matrix B is packed, its first dimension is effectively divided by 4. So, while we process the first segment of matrix A,
+    we still iterate over the entire first dimension of matrix B.
+
+    In each of the 4 iterations of the outer loop, we go through the full blocks of matrix B, but what changes is the data we extract.
+    Matrix B elements contain 4 weights, all packed into an int8 format, and during each iteration of the outer loop,
+    we extract a different weight by using bitwise shifting operations. This way, we access a unique weight on each pass.
+    """
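+    # Worked example of the bit extraction below (byte value assumed for illustration):
+    # a packed byte 0b11010010 holds four 2-bit codes; for i = 1 the mask is
+    # 3 << 2 = 0b00001100, so (byte & mask) >> 2 = 0b00 = 0, and subtracting the
+    # packing bias of 1 recovers the signed ternary weight -1.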
+    for i in range(4):
+        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+        for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K)):
+            k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j
+            # load the block of matrix A
+            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
+            # load the block of matrix B
+            b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K, other=0)
+            # when i = 0, for example, we only care about the first 2 bits of the elements of matrix B, so we use the mask 00000011 to mask out the other bits
+            mask = 3 << (2 * i)
+            # we shift the result after the mask
+            b = (b_uint8 & mask) >> (2 * i)
+            # During the packing of the weights, it's easier to pack 0, 1, 2 than -1, 0, 1, so we add 1 to the weight tensor, and we subtract it here
+            tensor_full = tl.full((1,), 1, dtype=tl.int8)
+            # We accumulate the result of the multiplication of the blocks along the K dimension in int32 to avoid any overflows or underflows.
+            accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32)
+            # we move the pointers: for a_ptrs we move horizontally along the second dimension -> we use stride_ak = 1
+            # for b_ptrs we move vertically, along the rows -> we use stride_bk = N
+            a_ptrs += BLOCK_SIZE_K * stride_ak
+            b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    c = accumulator
+    # These lines compute the offsets into matrix C where the result of this block's computation should be stored.
+    # stride_cm = N & stride_cn = 1
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    # we do a boundary check to ensure only elements within matrix bounds are stored
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+def matmul(a, b):
+    assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix needs to be packed"
+    assert a.is_contiguous(), "Matrix A must be contiguous"
+    M, K = a.shape
+    _, N = b.shape
+    # c is in int32 to avoid any overflows or underflows
+    c = torch.empty((M, N), device=a.device, dtype=torch.int32)
+    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
+    matmul_kernel[grid](
+        a,
+        b,
+        c,
+        M,
+        N,
+        K,
+        a.stride(0),
+        a.stride(1),
+        b.stride(0),
+        b.stride(1),
+        c.stride(0),
+        c.stride(1),
+    )
+    return c
diff --git a/src/liger_kernel/ops/fused_add_rms_norm.py b/src/liger_kernel/ops/fused_add_rms_norm.py
new file mode 100755
index 0000000000000000000000000000000000000000..866377687fd44f5374d4f1eef9e5e15cf8d9cbad
--- /dev/null
+++ b/src/liger_kernel/ops/fused_add_rms_norm.py
@@ -0,0 +1,410 @@
+import math
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import get_npu_core_count
+from liger_kernel.ops.utils import set_large_grf_mode
+from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import is_npu_available
+
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
+    try:
+        # 
typical import path with dispatch available + from triton.language.extra.libdevice import rsqrt + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import rsqrt +else: + from triton.language.math import rsqrt + + +_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1) +_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0) +_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1) + + +@triton.jit +def _fused_add_rms_norm_forward_kernel( + Y_ptr, + Y_row_stride, + S_ptr, # output residual + S_row_stride, + X_ptr, + X_row_stride, + R_ptr, # input residual + R_row_stride, + W_ptr, + W_row_stride, + RSTD_ptr, + RSTD_row_stride, + n_cols, + eps, + offset, + casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out + BLOCK_SIZE: tl.constexpr, +): + """ + This kernel computes the following: + 1. hidden_states = residual + hidden_states + 2. residual = hidden_states + 3. hidden_states = rmsnorm(hidden_states) + + This is a commonly used pattern in the decoder layers of LLMs. + Some examples: + 1. https://github.com/huggingface/transformers/blob/0dc2df5ddafe3cb5824ad24e85beba13e0aa6726/src/transformers/models/qwen3/modeling_qwen3.py#L271 + 2. https://github.com/huggingface/transformers/blob/0dc2df5ddafe3cb5824ad24e85beba13e0aa6726/src/transformers/models/llama4/modeling_llama4.py#L393 + + This kernel is inspired by the rms_norm forward kernel, and is adapted to support the residual addition in the forward pass. + The backward pass is also adapted to support the residual addition in the backward pass. + """ + + row_idx = tl.program_id(0).to(tl.int64) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + Y_ptr += row_idx * Y_row_stride + S_ptr += row_idx * S_row_stride + X_ptr += row_idx * X_row_stride + R_ptr += row_idx * R_row_stride + RSTD_ptr += row_idx * RSTD_row_stride + + X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0) + R_row = tl.load(R_ptr + col_offsets, mask=mask, other=0) + S_row = X_row + R_row + tl.store(S_ptr + col_offsets, S_row, mask=mask) + S_row_dtype = S_row.dtype + W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0) + + # On Llama, only rstd is computed on fp32 + if casting_mode == _CASTING_MODE_LLAMA: + S_row = S_row.to(tl.float32) + + # Gemma computes everything on fp32, and then casts back the output to the original dtype + if casting_mode == _CASTING_MODE_GEMMA: + W_row = W_row.to(tl.float32) + S_row = S_row.to(tl.float32) + + if casting_mode == _CASTING_MODE_NONE: + eps = eps.to(S_row_dtype) + offset = offset.to(S_row_dtype) + + mean_square = tl.sum(S_row * S_row, axis=0) / n_cols + rstd = rsqrt(mean_square + eps) + + # We can save time by caching rms with minimal memory overhead + # because rms is much smaller compared to X_row, as rms is for each row. + # However, on the computation side, it can save 4 operations (*, sum, /, sqrt). 
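+    # Worked example (row values assumed for illustration): for a row [1, 2, 2] with eps = 0,
+    # mean_square = (1 + 4 + 4) / 3 = 3.0 and rstd = 1 / sqrt(3).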
+    tl.store(RSTD_ptr, rstd)
+
+    S_row = S_row * rstd
+
+    # On Llama, the multiplication with the weight is done in the original dtype
+    if casting_mode == _CASTING_MODE_LLAMA:
+        S_row = S_row.to(S_row_dtype)
+
+    Y_row = S_row * (offset + W_row)
+
+    if casting_mode == _CASTING_MODE_GEMMA:
+        Y_row = Y_row.to(S_row_dtype)
+
+    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+
+
+@triton.jit
+def _fused_add_rms_norm_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dS_out_ptr,
+    dS_out_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    X_dtype: tl.constexpr,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_ptr,
+    dW_row_stride,
+    n_rows,
+    n_cols,
+    offset,
+    rows_per_program: tl.constexpr,
+    casting_mode: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    has_dS_out: tl.constexpr,
+):
+    """
+    This kernel is adapted from the rms_norm backward kernel to support the residual
+    addition in the backward pass, for the following code pattern:
+    1. hidden_states = residual + hidden_states
+    2. residual = hidden_states
+    3. hidden_states = rmsnorm(hidden_states)
+
+    The gradients of hidden_states and residual come out to be exactly the same. The value of this gradient is
+    the sum of the gradient of the hidden_states in step 3 and the gradient of the residual in step 2.
+
+    The backward pass computation logic is the same as in the rms_norm backward kernel, except that the gradient
+    of the hidden_states in step 3 and the gradient of the residual in step 2 are summed up.
+    """
+
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+    W_row = W_row + offset
+
+    for row_idx in range(row_start, row_end):
+        dy_base = dY_ptr + row_idx * dY_row_stride
+        dx_base = dX_ptr + row_idx * dX_row_stride
+
+        x_base = X_ptr + row_idx * X_row_stride
+        rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
+
+        dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
+        X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
+
+        # Get cached rstd
+        rstd_row = tl.load(rstd_base)
+
+        X_row = X_row.to(tl.float32)
+
+        # Different backward graphs for different casting modes
+        if casting_mode == _CASTING_MODE_LLAMA:
+            m = (dY_row * W_row).to(tl.float32)
+
+        elif casting_mode == _CASTING_MODE_GEMMA:
+            dY_row = dY_row.to(tl.float32)
+            m = dY_row * W_row
+        else:
+            m = dY_row * W_row
+
+        dX_row = rstd_row * m
+
+        if has_dS_out:
+            ds_base = dS_out_ptr + row_idx * dS_out_row_stride
+            dS_out_row = tl.load(ds_base + col_offsets, mask=mask, other=0.0)
+            dX_row += (rstd_row) * (
+                -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
+            ) + dS_out_row
+        else:
+            dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
+
+        # calculate the gradient of W
+        if casting_mode == _CASTING_MODE_LLAMA:
+            dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+        else:
+            # here X_row is already in fp32 (see previous if block)
+            dW_row += dY_row * (X_row * rstd_row)
+
+        tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
+
+    tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+
+
+_str_to_casting_mode = {
+    "llama": _CASTING_MODE_LLAMA.value,
+    "gemma": _CASTING_MODE_GEMMA.value,
+    "none": _CASTING_MODE_NONE.value,
+}
+
+
+def fused_add_rms_norm_forward(X, R, 
+    if not isinstance(casting_mode, int):
+        assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
+        casting_mode = _str_to_casting_mode[casting_mode]
+    else:
+        assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
+
+    shape = X.shape
+    dim = shape[-1]
+    X = X.view(-1, dim)
+    R = R.view(-1, dim)
+    n_rows, n_cols = X.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+    S = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+    # RSTD caches the rstd (inverse RMS) of each row
+    # RSTD is always computed/stored in fp32 when using the Llama or Gemma casting mode
+    rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
+    RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
+
+    # Check constraints.
+    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between X.shape[1] and W.shape[0]"
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        set_large_grf_mode(kernel_args)
+
+    # TODO: add _block_fused_add_rms_norm_forward_kernel
+    _fused_add_rms_norm_forward_kernel[(n_rows,)](
+        Y,
+        Y.stride(0),
+        S,
+        S.stride(0),
+        X,
+        X.stride(0),
+        R,
+        R.stride(0),
+        W,
+        W.stride(0),
+        RSTD,
+        RSTD.stride(0),
+        n_cols,
+        eps,
+        offset,
+        casting_mode,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        **kernel_args,  # XPU-specific optimization
+    )
+
+    return Y.view(*shape), S.view(*shape), RSTD, BLOCK_SIZE, num_warps, casting_mode
+
+
+def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
+    shape = dY.shape
+    dim = shape[-1]
+    dY = dY.view(-1, dim)
+    dS_out = dS_out.view(-1, dim)
+    S = S.view(-1, dim)
+    n_rows, n_cols = dY.shape
+
+    sm_count = 1
+    if S.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(S.device).multi_processor_count
+    elif S.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(S.device).gpu_eu_count
+    elif S.device.type == "npu":
+        sm_count = get_npu_core_count()
+
+    # Accumulate dW partials in fp32 for numerical stability.
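+    # Each program reduces rows_per_program rows into its own _dW slice; the
+    # final dW is _dW.sum(dim=0). Illustrative sizing (hypothetical values):
+    # n_rows=8192, n_cols=4096, sm_count=108 gives rows_per_program =
+    # ceil(8192 / 108) = 76 and _dW of shape (108, 4096).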
+    _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+
+    if n_cols > BLOCK_SIZE:
+        raise RuntimeError("This rms norm doesn't support feature dim >= 64KB.")
+    rows_per_program = math.ceil(n_rows / sm_count)
+    grid = (sm_count,)
+
+    if in_place is True:
+        dX = dY
+    else:
+        dX = torch.empty_like(dY)
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if S.device.type == "xpu":
+        set_large_grf_mode(kernel_args)
+
+    # TODO: add _block_fused_add_rms_norm_backward_kernel
+    _fused_add_rms_norm_backward_kernel[grid](
+        dY,
+        dY.stride(0),
+        dS_out,
+        dS_out.stride(0),
+        dX,
+        dX.stride(0),
+        S,
+        S.stride(0),
+        torch_to_triton_dtype[S.dtype],
+        W,
+        W.stride(0),
+        RSTD,
+        RSTD.stride(0),
+        _dW,
+        _dW.stride(0),
+        n_rows,
+        n_cols,
+        offset,
+        rows_per_program,
+        casting_mode,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        has_dS_out=dS_out is not None,
+        **kernel_args,  # XPU-specific optimization
+    )
+
+    dX = dX.view(*shape)
+    dW = _dW.sum(dim=0).to(W.dtype)
+
+    return dX, dX, dW  # dR is equal to dX
+
+
+class LigerFusedAddRMSNormFunction(torch.autograd.Function):
+    """
+    Performs a fused operation that first adds a residual tensor to the hidden_states tensor (`X`), then applies RMSNorm (Root Mean Square Normalization) to the result using the weight tensor `W`, with optional offset and casting mode.
+
+    This class implements the following sequence, commonly used in transformer decoder layers:
+    1. hidden_states = residual + hidden_states
+    2. residual = hidden_states (after addition)
+    3. hidden_states = rmsnorm(hidden_states)
+
+    Both the normalized hidden_states and the updated residual are returned as outputs.
+
+    Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
+    uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
+    `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
+
+    In addition, different models cast their inputs at different places during RMSNorm computation. For
+    example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
+    inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
+    support the following casting modes (they match HuggingFace Transformers' implementations):
+    - 'llama': matches the Llama implementation, where only the inverse RMS is computed in fp32.
+    - 'gemma': matches the Gemma implementation, where everything is cast to fp32, computed, then cast back to the original dtype.
+    - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has a larger error w.r.t. the original implementation.
+
+    The `in_place` option determines whether to modify dY in-place to store dX during the backward pass. It defaults to `False`; setting it to `True` saves memory.
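+
+    A minimal usage sketch (shapes, eps, and device are illustrative):
+
+        X = torch.randn(2, 16, 4096, device="cuda")
+        R = torch.randn(2, 16, 4096, device="cuda")
+        W = torch.ones(4096, device="cuda")
+        Y, S = LigerFusedAddRMSNormFunction.apply(X, R, W, 1e-6, 0.0, "llama")
+        # Y is rmsnorm(X + R) scaled by W; S is the updated residual X + R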
+ """ + + @staticmethod + @ensure_contiguous + def forward(ctx, X, R, W, eps, offset=0.0, casting_mode="llama", in_place=False): + """ + X: (B, T, H) or (BxT, H) + W: (H,) + """ + # TODO: add row_mode + Y, S, RSTD, BLOCK_SIZE, num_warps, casting_mode = fused_add_rms_norm_forward(X, R, W, eps, offset, casting_mode) + ctx.offset = offset + ctx.casting_mode = casting_mode + ctx.in_place = in_place + ctx.BLOCK_SIZE = BLOCK_SIZE + ctx.num_warps = num_warps + ctx.save_for_backward(S, W, RSTD) + return Y, S + + @staticmethod + @ensure_contiguous + def backward(ctx, dY, dS_out): + """ + Y: (B, T, H) or (BxT, H) + """ + S, W, RSTD = ctx.saved_tensors + dX, dR, dW = fused_add_rms_norm_backward( + dY, + dS_out, + S, + W, + RSTD, + ctx.offset, + ctx.casting_mode, + ctx.BLOCK_SIZE, + ctx.num_warps, + ctx.in_place, + ) + + return dX, dR, dW, None, None, None, None, None diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py new file mode 100755 index 0000000000000000000000000000000000000000..01f1b565866a2e2cc9b6c60e5898fc269250bcd5 --- /dev/null +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -0,0 +1,400 @@ +import torch +import triton + +from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel +from liger_kernel.ops.utils import amp_custom_bwd +from liger_kernel.ops.utils import amp_custom_fwd +from liger_kernel.ops.utils import element_mul_kernel +from liger_kernel.ops.utils import is_hip +from liger_kernel.utils import infer_device + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 2048 if infer_device() == "npu" else 65536 // 2 + + +def fused_linear_cross_entropy_forward( + _input, + weight, + target, + ce_weight=None, + bias=None, + ignore_index=-100, + lse_square_scale=0.0, + label_smoothing=0.0, + reduction="mean", + softcap=None, + return_z_loss=False, + accum_dtype=None, + use_token_scaling=False, + return_token_accuracy=False, + return_predicted_tokens=False, +): + assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" + assert isinstance(return_token_accuracy, bool), ( + f"return_token_accuracy must be True or False. Got: {return_token_accuracy}" + ) + assert isinstance(return_predicted_tokens, bool), ( + f"return_predicted_tokens must be True or False. Got: {return_predicted_tokens}" + ) + device = _input.device + + input_requires_grad = _input.requires_grad + + # inputs have shape: BT x H + # materialized activations will have shape: BT x V + # the increase in memory = BT x V + # reduction can be achieved by partitioning the number of tokens BT into smaller chunks. 
+    # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
+    # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
+    # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
+    BT, H = _input.shape
+    V = weight.shape[0]
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+    inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
+    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
+    num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
+
+    grad_input = torch.zeros_like(_input, device=device)
+
+    # the loss is accumulated in fp32; gradient accumulators use accum_dtype if provided, otherwise the original dtype
+    if input_requires_grad:
+        if accum_dtype is None:
+            grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+        else:
+            grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+    else:
+        grad_weight = None
+        grad_bias = None
+
+    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
+    z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None
+    predicted_tokens_1d = torch.full((BT,), -1, dtype=torch.int64, device=device) if return_predicted_tokens else None
+
+    # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
+    target_mask = target != ignore_index
+    total_n_non_ignore = target_mask.sum().item()
+    total_sum_non_ignore_ce_weight = total_n_non_ignore
+    ce_weight_sum = 0.0
+    if ce_weight is not None:
+        assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
+        assert torch.is_floating_point(ce_weight), (
+            f"If given, weight has to be a Tensor of floating point dtype. 
Got: {ce_weight.dtype}" + ) + total_sum_non_ignore_ce_weight = ( + torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item() + ) + ce_weight_sum = ce_weight.sum().item() + if ce_weight.stride(-1) != 1: + ce_weight = ce_weight.contiguous() + + for chunk_id in range(num_chunks): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + _input_chunk = _input[start_idx:end_idx] # chunk_size x H + + # when doing matmul, use the original precision + logits_chunk = _input_chunk @ weight.t() # chunk_size x V + if bias is not None: + logits_chunk = logits_chunk + bias + + target_chunk = target[start_idx:end_idx] # chunk_size, + + n_rows = logits_chunk.shape[0] + + # Compute predicted probabilities for token scaling if needed + if use_token_scaling: + # Compute softmax probabilities for scaling + # We need to compute this before the cross entropy kernel modifies logits_chunk + logits_for_softmax = logits_chunk.detach().clone() # Detach to avoid gradient flow + if softcap is not None: + logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap) + + # Compute softmax to get predicted probabilities + probs = torch.softmax(logits_for_softmax, dim=-1) + + # Get predicted probabilities for token scaling, handling ignored targets + valid_target_mask = target_chunk != ignore_index + valid_targets = target_chunk[valid_target_mask] + + if len(valid_targets) > 0: + # Gather probabilities only for valid targets + valid_probs = probs[valid_target_mask] + pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1) + + # Create full tensor with zeros for ignored targets + pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device) + pred_probs[valid_target_mask] = pred_probs_valid + else: + # All targets are ignored + pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device) + + # Store the scaling factors + scaling_factors = pred_probs.detach() # Detach to ensure no gradient flow + + # unreduced loss + loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size, + z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None + token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None + predicted_tokens_1d_slice = predicted_tokens_1d[start_idx:end_idx] if return_predicted_tokens else None + + # ensure _input and target are contiguous + logits_chunk = logits_chunk.contiguous() + target_chunk = target_chunk.contiguous() + + # Here we calculate the gradient of logits_chunk in place so we can save memory. 
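+        # Illustrative memory math (hypothetical sizes): with BT=16384, V=32000
+        # in bf16, keeping full logits plus a separate gradient buffer would cost
+        # 2 * 16384 * 32000 * 2 bytes ~= 2 GiB; reusing each logits_chunk as its
+        # own gradient keeps the extra footprint to one chunk_size x V slice.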
+        liger_cross_entropy_kernel[(n_rows,)](
+            X_ptr=logits_chunk,
+            X_stride=logits_chunk.stride(-2),
+            Y_ptr=target_chunk,
+            Y_stride=target_chunk.stride(-1),  # always 1
+            weight_ptr=ce_weight,
+            loss_ptr=loss_1d_slice,
+            z_loss_ptr=z_loss_1d_slice,
+            loss_stride=loss_1d_slice.stride(-1),  # always 1
+            token_accuracy_ptr=token_accuracy_1d_slice,
+            token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
+            if return_token_accuracy
+            else 0,  # always 1 if accuracy is enabled
+            predicted_tokens_ptr=predicted_tokens_1d_slice,
+            predicted_tokens_stride=predicted_tokens_1d_slice.stride(-1)
+            if return_predicted_tokens
+            else 0,  # always 1 if predicted tokens is enabled
+            n_cols=V,
+            n_non_ignore=total_n_non_ignore,
+            sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
+            weight_sum=ce_weight_sum,
+            ignore_index=ignore_index,
+            lse_square_scale=lse_square_scale,
+            label_smoothing=label_smoothing,
+            reduction=reduction,
+            softcap=softcap,
+            RETURN_Z_LOSS=return_z_loss,
+            RETURN_TOKEN_ACCURACY=return_token_accuracy,
+            RETURN_PREDICTED_TOKENS=return_predicted_tokens,
+            HAS_WEIGHT=True if ce_weight is not None else False,
+            HAS_SOFTCAPPING=True if softcap is not None else False,
+            HAS_GRADIENTS=input_requires_grad,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=32 if not is_hip() else 16,
+        )
+
+        # Apply token scaling if requested
+        if use_token_scaling:
+            loss_1d_slice = loss_1d_slice * scaling_factors
+            if return_z_loss:
+                z_loss_1d_slice = z_loss_1d_slice * scaling_factors
+
+        loss_1d[start_idx:end_idx] = loss_1d_slice
+        if return_z_loss:
+            z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+        if return_token_accuracy:
+            token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
+        if return_predicted_tokens:
+            predicted_tokens_1d[start_idx:end_idx] = predicted_tokens_1d_slice
+        grad_logits_chunk = logits_chunk  # chunk_size x V
+
+        # Apply token scaling to gradients if requested
+        if use_token_scaling:
+            # Expand scaling factors to match gradient dimensions
+            scaling_factors_expanded = scaling_factors.unsqueeze(-1)  # chunk_size x 1
+            grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded
+
+        if input_requires_grad:
+            grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+
+        if grad_weight is not None and input_requires_grad:
+            grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
+
+        if bias is not None and input_requires_grad:
+            torch.add(
+                input=grad_bias,
+                other=grad_logits_chunk.sum(dim=0),
+                out=grad_bias,
+                alpha=1.0,
+            )
+
+    # Note: reduction == 'none' would need extra handling in the backward pass;
+    # the backward kernel currently assumes a scalar grad_output.
+    if reduction == "none":
+        # Return per-token losses
+        loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
+    else:
+        loss = torch.sum(loss_1d)
+        z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None
+
+    predicted_tokens = predicted_tokens_1d if return_predicted_tokens else None
+
+    # Cast back to the original dtype
+    grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
+    grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
+
+    return loss, z_loss, token_accuracy, predicted_tokens, grad_input, grad_weight, grad_bias
+
+
+def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
+    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
+    if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+        # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
+        # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
+        BT, H = grad_input.shape
+        n_rows = BT
+        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
+
+        element_mul_kernel[(n_rows,)](
+            grad_input,
+            grad_input.stride(-2),
+            grad_output,
+            H,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=32 if not is_hip() else 16,
+        )
+
+        # handle grad_weight
+        if grad_weight is not None:
+            V, H = grad_weight.shape
+            n_rows = V
+
+            element_mul_kernel[(n_rows,)](
+                grad_weight,
+                grad_weight.stride(-2),
+                grad_output,
+                H,
+                BLOCK_SIZE=BLOCK_SIZE,
+                num_warps=32 if not is_hip() else 16,
+            )
+
+        if grad_bias is not None:
+            V = grad_bias.shape[0]
+            n_rows = V
+
+            element_mul_kernel[(n_rows,)](
+                grad_bias,
+                grad_bias.stride(-1),
+                grad_output,
+                1,
+                BLOCK_SIZE=BLOCK_SIZE,
+                num_warps=32 if not is_hip() else 16,
+            )
+    return grad_input, grad_weight, grad_bias
+
+
+class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
+    @staticmethod
+    @amp_custom_fwd
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ce_weight=None,
+        ignore_index=-100,
+        lse_square_scale=0.0,
+        label_smoothing=0.0,
+        reduction="mean",
+        softcap=None,
+        return_z_loss: bool = False,
+        accum_dtype=None,
+        use_token_scaling: bool = False,
+        return_token_accuracy: bool = False,
+        return_predicted_tokens: bool = False,
+    ):
+        """
+        Fusing the last linear layer with cross-entropy loss
+        Reference: https://github.com/mgmalek/efficient_cross_entropy
+
+        Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
+        the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
+        compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
+        for the backward pass.
+
+        _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
+        target: (B*T) where each value is in [0, V-1]
+        weight: (V, H) where V is the number of classes
+        bias: (V) where V is the number of classes
+        ce_weight: a manual rescaling weight given to each class. 
If given, has to be a Tensor of size V and floating point dtype + ignore_index: the index to ignore in the target + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + reduction: reduction to apply + accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations. + Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype + use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached). + When True, each token's loss is multiplied by the model's predicted probability for that token's true class. + Default: False. + return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False` + return_predicted_tokens (bool): When `return_predicted_tokens` is `True`, returns per-token predicted class indices (argmax) without materializing logits. Default: `False` + """ + + loss, z_loss, token_accuracy, predicted_tokens, grad_input, grad_weight, grad_bias = ( + fused_linear_cross_entropy_forward( + _input=_input, + weight=weight, + target=target, + bias=bias, + ce_weight=ce_weight, + ignore_index=ignore_index, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, + return_z_loss=return_z_loss, + accum_dtype=accum_dtype, + use_token_scaling=use_token_scaling, + return_token_accuracy=return_token_accuracy, + return_predicted_tokens=return_predicted_tokens, + ) + ) + # downcast to dtype and store for backward + ctx.save_for_backward( + grad_input.detach(), + grad_weight.detach() if grad_weight is not None else None, + grad_bias.detach() if grad_bias is not None else None, + ) + ctx.return_z_loss = return_z_loss + ctx.return_token_accuracy = return_token_accuracy + ctx.return_predicted_tokens = return_predicted_tokens + return loss, z_loss, token_accuracy, predicted_tokens + + @staticmethod + @amp_custom_bwd + def backward(ctx, grad_output, grad_output2, grad_output3, grad_output4): + if ctx.return_z_loss: + del grad_output2 # z_loss is only for logging + if ctx.return_token_accuracy: + del grad_output3 # token_accuracy is only for metrics + if ctx.return_predicted_tokens: + del grad_output4 # predicted_tokens is only for metrics + (grad_input, grad_weight, grad_bias) = ctx.saved_tensors + grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( + grad_output, grad_input, grad_weight, grad_bias + ) + return ( + grad_input, + grad_weight, + None, + grad_bias, + None, + None, + None, + None, + None, + None, + None, + None, + None, # use_token_scaling + None, # return_token_accuracy + None, # return_predicted_tokens + ) diff --git a/src/liger_kernel/ops/fused_linear_jsd.py b/src/liger_kernel/ops/fused_linear_jsd.py new file mode 100755 index 0000000000000000000000000000000000000000..e31b10769b6004522cea805d459b5c59a5ed56b9 --- /dev/null +++ b/src/liger_kernel/ops/fused_linear_jsd.py @@ -0,0 +1,228 @@ +from typing import Optional + +import torch +import triton + +from liger_kernel.ops.jsd import _jsd_kernel +from liger_kernel.ops.utils import amp_custom_bwd +from liger_kernel.ops.utils import amp_custom_fwd +from liger_kernel.ops.utils import element_mul_kernel +from liger_kernel.ops.utils import is_hip +from liger_kernel.utils import infer_device + +# The hard limit of 
TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 + + +def fused_linear_jsd_forward( + student_input, + student_weight, + teacher_input, + teacher_weight, + shift_labels, + jsd_beta, + ignore_index, + has_label, + temperature, +): + device = student_input.device + dtype = student_input.dtype + + # inputs have shape: BT x H + # materialized activations will have shape: BT x V + # the increase in memory = BT x V + # reduction can be achieved by partitioning the number of tokens BT into smaller chunks. + # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: + # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor + # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 + BT, H = student_input.shape + V = student_weight.shape[0] + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + inc_factor = triton.cdiv(V, H) # (V + H - 1) // H + chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor + num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size + + grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None + grad_input = torch.zeros_like(student_input) + # we use fp32 for loss accumulator + loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device) + + if has_label: + n_non_ignore = (shift_labels != ignore_index).sum().item() + else: + n_non_ignore = BT + + for chunk_id in range(num_chunks): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + + # chunk both inputs, shape: chunk_size x H + student_input_chunk = student_input[start_idx:end_idx] + teacher_input_chunk = teacher_input[start_idx:end_idx] + + # shape: chunk_size x V + # For anything starting from logits to the final JSD loss, we do computation + # in FP32 to avoid losing numerical stability. + student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32) + teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32) + chunk_n_rows = student_logits_chunk.shape[0] + + # unreduced loss + loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size + # log-softmax with temperature + student_logits_chunk = student_logits_chunk / temperature + teacher_logits_chunk = teacher_logits_chunk / temperature + student_prob_chunk = torch.log_softmax(student_logits_chunk, dim=-1) + teacher_prob_chunk = torch.log_softmax(teacher_logits_chunk, dim=-1) + + # ensure _input and target are contiguous + student_prob_chunk = student_prob_chunk.contiguous() + teacher_prob_chunk = teacher_prob_chunk.contiguous() + + # Here we calculate the gradient of prob_chunk in place so we can save memory. 
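+        # dX_ptr below aliases X_ptr (both point at student_prob_chunk), so the
+        # kernel overwrites the student log-probs with the loss gradient and no
+        # extra chunk_size x V buffer is allocated.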
+ _jsd_kernel[(chunk_n_rows,)]( + X_ptr=student_prob_chunk, + X_stride=student_prob_chunk.stride(-2), + Y_ptr=teacher_prob_chunk, + Y_stride=teacher_prob_chunk.stride(-2), + loss_ptr=loss_1d_slice, + loss_stride=loss_1d_slice.stride(-2), + dX_ptr=student_prob_chunk, + dX_stride=student_prob_chunk.stride(-2), + label_ptr=( + shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device) + ), # dummy ptr if no label + beta=jsd_beta, + n_non_ignore=n_non_ignore, + ignore_index=ignore_index, + n_cols=V, + BLOCK_SIZE=BLOCK_SIZE, + HAS_LABEL=has_label, + ) + loss_1d[start_idx:end_idx] = loss_1d_slice + # gradients of prob_chunk in place, shape: chunk_size x V + # gradients of logits_chunk in place, shape: chunk_size x V + student_logits_chunk = ( + student_prob_chunk + - torch.softmax(student_logits_chunk, dim=-1) + * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape) + ) / temperature + # now we traverse back to grad w.r.t. input to `lm_head` and grad + # w.r.t. `lm_head` which should be computed in original dtype + student_logits_chunk = student_logits_chunk.to(dtype) + grad_input[start_idx:end_idx] = student_logits_chunk @ student_weight + + if grad_weight is not None: + grad_weight.add_(student_logits_chunk.t() @ student_input_chunk) + + loss = torch.sum(loss_1d) + return loss, grad_input, grad_weight + + +def fused_linear_jsd_backward(grad_output, grad_input, grad_weight): + # If JSD is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + BT, H = grad_input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + + element_mul_kernel[(n_rows,)]( + grad_input, + grad_input.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + + # handle grad_weight + if grad_weight is not None: + V, H = grad_weight.shape + n_rows = V + + element_mul_kernel[(n_rows,)]( + grad_weight, + grad_weight.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + + return grad_input, grad_weight + + +class LigerFusedLinearJSDFunction(torch.autograd.Function): + """ + Fusing the last linear layer with generalized JSD + + Handle the forward and backward pass of the final linear layer via JSD by avoiding + the materialization of the large logits tensor. Since JSD is the last layer, we can + compute the gradient at the forward pass. + """ + + @staticmethod + @amp_custom_fwd + def forward( + ctx, + student_input: torch.Tensor, + student_weight: torch.Tensor, + teacher_input: torch.Tensor, + teacher_weight: torch.Tensor, + shift_labels: Optional[torch.Tensor] = None, + jsd_beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + ): + """ + Args: + + student_input (torch.tensor): input of the last projection layer in student model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension. + student_weight (torch.tensor): the last projection layer in student model, with shape (V, H), where V is vocab size + teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension. 
+ teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size + shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1]. + jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5` + ignore_index (int): the index to ignore. Default: -100 + temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0` + + Returns: + loss (torch.Tensor): generalized JSD + """ + has_label = False + if shift_labels is not None: + assert shift_labels.shape == (teacher_input.shape[0],), ( + f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" + ) + shift_labels = shift_labels.contiguous() + has_label = True + + loss, grad_input, grad_weight = fused_linear_jsd_forward( + student_input, + student_weight, + teacher_input, + teacher_weight, + shift_labels, + jsd_beta, + ignore_index, + has_label, + temperature, + ) + # downcast to dtype and store for backward + ctx.save_for_backward( + grad_input.detach(), + grad_weight.detach() if grad_weight is not None else None, + ) + return loss + + @staticmethod + @amp_custom_bwd + def backward(ctx, grad_output): + (grad_input, grad_weight) = ctx.saved_tensors + grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight) + return (grad_input, grad_weight, None, None, None, None, None, None) diff --git a/src/liger_kernel/ops/fused_neighborhood_attention.py b/src/liger_kernel/ops/fused_neighborhood_attention.py new file mode 100755 index 0000000000000000000000000000000000000000..557358fc9a1cb2916aaeb6a61c6da47ba4a9f3fc --- /dev/null +++ b/src/liger_kernel/ops/fused_neighborhood_attention.py @@ -0,0 +1,1022 @@ +import math + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.softmax import _softmax_backward +from liger_kernel.ops.softmax import _softmax_forward +from liger_kernel.ops.utils import calculate_settings +from liger_kernel.ops.utils import ensure_contiguous + + +@triton.jit +def _neighborhood_mask_kernel( + mask_ptr, + seq_len: tl.constexpr, + kernel_size: tl.constexpr, + dilation: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + num_stages: tl.constexpr, + num_warps: tl.constexpr, +): + """ + Generate a neighborhood attention mask for a given sequence. + + This kernel creates a binary mask that defines which positions in a sequence + can attend to each other based on a neighborhood window with optional dilation. + Each row of the mask corresponds to a query position, and each column indicates + whether that key position is within the allowed neighborhood. + + The neighborhood is defined as positions within kernel_size//2 * dilation distance + from the center position. When dilation > 1, only positions at multiples of the + dilation factor are included in the neighborhood. + + Args: + mask_ptr: Pointer to the output mask tensor [seq_len, seq_len] + seq_len: Length of the input sequence + kernel_size: Size of the neighborhood window (must be odd) + dilation: Dilation factor for the neighborhood pattern + BLOCK_SIZE: Block size for processing (compile-time constant) + num_stages: Number of pipeline stages (compile-time constant) + num_warps: Number of warps (compile-time constant) + + Grid: (seq_len,) + Each program processes one row of the mask matrix. 
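+
+    Illustrative pattern (hypothetical sizes): with seq_len=8, kernel_size=3,
+    dilation=1, row 4 gets ones at columns {3, 4, 5}; with dilation=2 the same
+    row gets ones at columns {2, 4, 6}.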
+ """ + row_id = tl.program_id(0) + + center = row_id + half_kernel = kernel_size // 2 + + start = tl.maximum(0, center - half_kernel * dilation) + end = tl.minimum(seq_len, center + half_kernel * dilation + 1) + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < seq_len + + valid_neighbors = (col_offsets >= start) & (col_offsets < end) + if dilation > 1: + relative_pos = col_offsets - center + valid_dilation = (relative_pos % dilation) == 0 + valid_neighbors = valid_neighbors & valid_dilation + + mask_values = tl.where(valid_neighbors & mask, 1.0, 0.0) + + base_offset = row_id * seq_len + tl.store(mask_ptr + base_offset + col_offsets, mask_values, mask=mask) + + +@triton.jit +def _fused_neighborhood_attention_qk_kernel( + Q_ptr, + K_ptr, + QK_ptr, + mask_ptr, + q_batch_stride, + q_head_stride, + q_seq_stride, + q_dim_stride, + k_batch_stride, + k_head_stride, + k_seq_stride, + k_dim_stride, + qk_batch_stride, + qk_head_stride, + qk_seq_stride, + qk_seq2_stride, + batch_size: tl.constexpr, + num_heads: tl.constexpr, + seq_len: tl.constexpr, + head_dim: tl.constexpr, + scale: tl.constexpr, + kernel_size: tl.constexpr, + dilation: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + num_stages: tl.constexpr, + num_warps: tl.constexpr, +): + """ + Compute Q @ K^T with neighborhood masking and scaling. + + This kernel performs the first stage of neighborhood attention by computing + the attention scores between queries and keys, applying scaling, and masking + positions outside the neighborhood window. The result is a matrix of attention + scores ready for softmax normalization. + + The computation is tiled across sequence dimensions for memory efficiency. + Each tile computes a block of the attention score matrix by iterating over + the head dimension and accumulating dot products. + + Args: + Q_ptr: Pointer to query tensor [batch_size, num_heads, seq_len, head_dim] + K_ptr: Pointer to key tensor [batch_size, num_heads, seq_len, head_dim] + QK_ptr: Pointer to output tensor [batch_size, num_heads, seq_len, seq_len] + mask_ptr: Pointer to neighborhood mask [seq_len, seq_len] + q_*_stride: Strides for query tensor + k_*_stride: Strides for key tensor + qk_*_stride: Strides for output tensor + batch_size: Number of batches + num_heads: Number of attention heads + seq_len: Sequence length + head_dim: Dimension of each attention head + scale: Scaling factor for attention scores (typically 1/sqrt(head_dim)) + kernel_size: Size of the neighborhood window + dilation: Dilation factor for the neighborhood + BLOCK_SIZE_M: Block size for sequence dimension (rows) + BLOCK_SIZE_N: Block size for sequence dimension (cols) + BLOCK_SIZE_K: Block size for head dimension + num_stages: Number of pipeline stages + num_warps: Number of warps + + Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(seq_len, BLOCK_SIZE_N)) + Each program computes a tile of the attention score matrix. 
+ """ + batch_head_id = tl.program_id(0) + tile_m = tl.program_id(1) + tile_n = tl.program_id(2) + + batch_id = batch_head_id // num_heads + head_id = batch_head_id % num_heads + + row_start = tile_m * BLOCK_SIZE_M + col_start = tile_n * BLOCK_SIZE_N + + row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M) + col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k_start in range(0, head_dim, BLOCK_SIZE_K): + k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K) + k_mask = k_offsets < head_dim + + q_ptrs = ( + Q_ptr + + batch_id * q_batch_stride + + head_id * q_head_stride + + row_offsets[:, None] * q_seq_stride + + k_offsets[None, :] * q_dim_stride + ) + q_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :] + q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0) + + k_ptrs = ( + K_ptr + + batch_id * k_batch_stride + + head_id * k_head_stride + + col_offsets[:, None] * k_seq_stride + + k_offsets[None, :] * k_dim_stride + ) + k_mask = (col_offsets[:, None] < seq_len) & k_mask[None, :] + k_chunk = tl.load(k_ptrs, mask=k_mask, other=0.0) + + acc += tl.dot(q_chunk, tl.trans(k_chunk)) + + acc = acc * scale + + mask_ptrs = mask_ptr + row_offsets[:, None] * seq_len + col_offsets[None, :] + valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < seq_len) + neighborhood_mask = tl.load(mask_ptrs, mask=valid_mask, other=0.0) + + acc = tl.where(neighborhood_mask > 0.0, acc, float("-inf")) + + qk_ptrs = ( + QK_ptr + + batch_id * qk_batch_stride + + head_id * qk_head_stride + + row_offsets[:, None] * qk_seq_stride + + col_offsets[None, :] * qk_seq2_stride + ) + tl.store(qk_ptrs, acc, mask=valid_mask) + + +@triton.jit +def _fused_neighborhood_attention_av_kernel( + Attn_ptr, + V_ptr, + Out_ptr, + attn_batch_stride, + attn_head_stride, + attn_seq_stride, + attn_seq2_stride, + v_batch_stride, + v_head_stride, + v_seq_stride, + v_dim_stride, + out_batch_stride, + out_head_stride, + out_seq_stride, + out_dim_stride, + batch_size: tl.constexpr, + num_heads: tl.constexpr, + seq_len: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + num_stages: tl.constexpr, + num_warps: tl.constexpr, +): + """ + Compute Attention @ V to produce the final output. + + This kernel performs the second stage of neighborhood attention by multiplying + the normalized attention weights with the value matrix. The computation is + tiled for memory efficiency, with each tile computing a block of the output. + + Args: + Attn_ptr: Pointer to attention weights [batch_size, num_heads, seq_len, seq_len] + V_ptr: Pointer to value tensor [batch_size, num_heads, seq_len, head_dim] + Out_ptr: Pointer to output tensor [batch_size, num_heads, seq_len, head_dim] + attn_*_stride: Strides for attention weights tensor + v_*_stride: Strides for value tensor + out_*_stride: Strides for output tensor + batch_size: Number of batches + num_heads: Number of attention heads + seq_len: Sequence length + head_dim: Dimension of each attention head + BLOCK_SIZE_M: Block size for sequence dimension (rows) + BLOCK_SIZE_N: Block size for head dimension (cols) + BLOCK_SIZE_K: Block size for sequence dimension (reduction) + num_stages: Number of pipeline stages + num_warps: Number of warps + + Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N)) + Each program computes a tile of the output matrix. 
+ """ + batch_head_id = tl.program_id(0) + tile_m = tl.program_id(1) + tile_n = tl.program_id(2) + + batch_id = batch_head_id // num_heads + head_id = batch_head_id % num_heads + + row_start = tile_m * BLOCK_SIZE_M + col_start = tile_n * BLOCK_SIZE_N + + row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M) + col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k_start in range(0, seq_len, BLOCK_SIZE_K): + k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K) + k_mask = k_offsets < seq_len + + attn_ptrs = ( + Attn_ptr + + batch_id * attn_batch_stride + + head_id * attn_head_stride + + row_offsets[:, None] * attn_seq_stride + + k_offsets[None, :] * attn_seq2_stride + ) + attn_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :] + attn_chunk = tl.load(attn_ptrs, mask=attn_mask, other=0.0) + + v_ptrs = ( + V_ptr + + batch_id * v_batch_stride + + head_id * v_head_stride + + k_offsets[:, None] * v_seq_stride + + col_offsets[None, :] * v_dim_stride + ) + v_mask = k_mask[:, None] & (col_offsets[None, :] < head_dim) + v_chunk = tl.load(v_ptrs, mask=v_mask, other=0.0) + + acc += tl.dot(attn_chunk, v_chunk) + + out_ptrs = ( + Out_ptr + + batch_id * out_batch_stride + + head_id * out_head_stride + + row_offsets[:, None] * out_seq_stride + + col_offsets[None, :] * out_dim_stride + ) + valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim) + tl.store(out_ptrs, acc, mask=valid_mask) + + +@triton.jit +def _fused_neighborhood_attention_grad_qk_kernel( + grad_attn_ptr, + K_ptr, + grad_Q_ptr, + grad_attn_batch_stride, + grad_attn_head_stride, + grad_attn_seq_stride, + grad_attn_seq2_stride, + k_batch_stride, + k_head_stride, + k_seq_stride, + k_dim_stride, + grad_q_batch_stride, + grad_q_head_stride, + grad_q_seq_stride, + grad_q_dim_stride, + batch_size: tl.constexpr, + num_heads: tl.constexpr, + seq_len: tl.constexpr, + head_dim: tl.constexpr, + scale: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + num_stages: tl.constexpr, + num_warps: tl.constexpr, +): + """ + Compute gradient with respect to queries: grad_Q = grad_attn @ K * scale. + + This kernel computes the gradient of the loss with respect to the query tensor + by multiplying the gradient of attention weights with the key tensor. The + computation follows the chain rule for the attention mechanism. + + Args: + grad_attn_ptr: Pointer to gradient of attention weights [batch_size, num_heads, seq_len, seq_len] + K_ptr: Pointer to key tensor [batch_size, num_heads, seq_len, head_dim] + grad_Q_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, head_dim] + grad_attn_*_stride: Strides for gradient attention tensor + k_*_stride: Strides for key tensor + grad_q_*_stride: Strides for gradient query tensor + batch_size: Number of batches + num_heads: Number of attention heads + seq_len: Sequence length + head_dim: Dimension of each attention head + scale: Scaling factor applied to attention scores + BLOCK_SIZE_M: Block size for sequence dimension (rows) + BLOCK_SIZE_N: Block size for head dimension (cols) + BLOCK_SIZE_K: Block size for sequence dimension (reduction) + num_stages: Number of pipeline stages + num_warps: Number of warps + + Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N)) + Each program computes a tile of the query gradient matrix. 
+ """ + batch_head_id = tl.program_id(0) + tile_m = tl.program_id(1) + tile_n = tl.program_id(2) + + batch_id = batch_head_id // num_heads + head_id = batch_head_id % num_heads + + row_start = tile_m * BLOCK_SIZE_M + col_start = tile_n * BLOCK_SIZE_N + + row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M) + col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k_start in range(0, seq_len, BLOCK_SIZE_K): + k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K) + k_mask = k_offsets < seq_len + + grad_attn_ptrs = ( + grad_attn_ptr + + batch_id * grad_attn_batch_stride + + head_id * grad_attn_head_stride + + row_offsets[:, None] * grad_attn_seq_stride + + k_offsets[None, :] * grad_attn_seq2_stride + ) + grad_attn_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :] + grad_attn_chunk = tl.load(grad_attn_ptrs, mask=grad_attn_mask, other=0.0) + + k_ptrs = ( + K_ptr + + batch_id * k_batch_stride + + head_id * k_head_stride + + k_offsets[:, None] * k_seq_stride + + col_offsets[None, :] * k_dim_stride + ) + k_mask_2d = k_mask[:, None] & (col_offsets[None, :] < head_dim) + k_chunk = tl.load(k_ptrs, mask=k_mask_2d, other=0.0) + + acc += tl.dot(grad_attn_chunk, k_chunk) + + acc = acc * scale + + grad_q_ptrs = ( + grad_Q_ptr + + batch_id * grad_q_batch_stride + + head_id * grad_q_head_stride + + row_offsets[:, None] * grad_q_seq_stride + + col_offsets[None, :] * grad_q_dim_stride + ) + valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim) + tl.store(grad_q_ptrs, acc, mask=valid_mask) + + +@triton.jit +def _fused_neighborhood_attention_grad_k_kernel( + grad_attn_ptr, + Q_ptr, + grad_K_ptr, + grad_attn_batch_stride, + grad_attn_head_stride, + grad_attn_seq_stride, + grad_attn_seq2_stride, + q_batch_stride, + q_head_stride, + q_seq_stride, + q_dim_stride, + grad_k_batch_stride, + grad_k_head_stride, + grad_k_seq_stride, + grad_k_dim_stride, + batch_size: tl.constexpr, + num_heads: tl.constexpr, + seq_len: tl.constexpr, + head_dim: tl.constexpr, + scale: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + num_stages: tl.constexpr, + num_warps: tl.constexpr, +): + """ + Compute gradient with respect to keys: grad_K = grad_attn^T @ Q * scale. + + This kernel computes the gradient of the loss with respect to the key tensor + by multiplying the transpose of the gradient of attention weights with the + query tensor. The computation follows the chain rule for the attention mechanism. 
+ + Args: + grad_attn_ptr: Pointer to gradient of attention weights [batch_size, num_heads, seq_len, seq_len] + Q_ptr: Pointer to query tensor [batch_size, num_heads, seq_len, head_dim] + grad_K_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, head_dim] + grad_attn_*_stride: Strides for gradient attention tensor + q_*_stride: Strides for query tensor + grad_k_*_stride: Strides for gradient key tensor + batch_size: Number of batches + num_heads: Number of attention heads + seq_len: Sequence length + head_dim: Dimension of each attention head + scale: Scaling factor applied to attention scores + BLOCK_SIZE_M: Block size for sequence dimension (rows) + BLOCK_SIZE_N: Block size for head dimension (cols) + BLOCK_SIZE_K: Block size for sequence dimension (reduction) + num_stages: Number of pipeline stages + num_warps: Number of warps + + Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N)) + Each program computes a tile of the key gradient matrix. + """ + batch_head_id = tl.program_id(0) + tile_m = tl.program_id(1) + tile_n = tl.program_id(2) + + batch_id = batch_head_id // num_heads + head_id = batch_head_id % num_heads + + row_start = tile_m * BLOCK_SIZE_M + col_start = tile_n * BLOCK_SIZE_N + + row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M) + col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k_start in range(0, seq_len, BLOCK_SIZE_K): + k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K) + k_mask = k_offsets < seq_len + + q_ptrs = ( + Q_ptr + + batch_id * q_batch_stride + + head_id * q_head_stride + + k_offsets[:, None] * q_seq_stride + + col_offsets[None, :] * q_dim_stride + ) + q_mask = k_mask[:, None] & (col_offsets[None, :] < head_dim) + q_chunk = tl.load(q_ptrs, mask=q_mask, other=0.0) + + grad_attn_T_ptrs = ( + grad_attn_ptr + + batch_id * grad_attn_batch_stride + + head_id * grad_attn_head_stride + + row_offsets[:, None] * grad_attn_seq2_stride + + k_offsets[None, :] * grad_attn_seq_stride + ) + grad_attn_T_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :] + grad_attn_T_chunk = tl.load(grad_attn_T_ptrs, mask=grad_attn_T_mask, other=0.0) + + acc += tl.dot(grad_attn_T_chunk, q_chunk) + + acc = acc * scale + + grad_k_ptrs = ( + grad_K_ptr + + batch_id * grad_k_batch_stride + + head_id * grad_k_head_stride + + row_offsets[:, None] * grad_k_seq_stride + + col_offsets[None, :] * grad_k_dim_stride + ) + valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim) + tl.store(grad_k_ptrs, acc, mask=valid_mask) + + +@triton.jit +def _fused_neighborhood_attention_grad_v_kernel( + Attn_ptr, + grad_output_ptr, + grad_V_ptr, + attn_batch_stride, + attn_head_stride, + attn_seq_stride, + attn_seq2_stride, + grad_out_batch_stride, + grad_out_head_stride, + grad_out_seq_stride, + grad_out_dim_stride, + grad_v_batch_stride, + grad_v_head_stride, + grad_v_seq_stride, + grad_v_dim_stride, + batch_size: tl.constexpr, + num_heads: tl.constexpr, + seq_len: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + num_stages: tl.constexpr, + num_warps: tl.constexpr, +): + """ + Compute gradient with respect to values: grad_V = Attn^T @ grad_output. + + This kernel computes the gradient of the loss with respect to the value tensor + by multiplying the transpose of the attention weights with the gradient of the + output. 
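+    In matrix form, with A the attention weights and O the output:
+    dL/dV = A^T @ (dL/dO).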
The computation follows the chain rule for the attention mechanism. + + Args: + Attn_ptr: Pointer to attention weights [batch_size, num_heads, seq_len, seq_len] + grad_output_ptr: Pointer to gradient of output [batch_size, num_heads, seq_len, head_dim] + grad_V_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, head_dim] + attn_*_stride: Strides for attention weights tensor + grad_out_*_stride: Strides for gradient output tensor + grad_v_*_stride: Strides for gradient value tensor + batch_size: Number of batches + num_heads: Number of attention heads + seq_len: Sequence length + head_dim: Dimension of each attention head + BLOCK_SIZE_M: Block size for sequence dimension (rows) + BLOCK_SIZE_N: Block size for head dimension (cols) + BLOCK_SIZE_K: Block size for sequence dimension (reduction) + num_stages: Number of pipeline stages + num_warps: Number of warps + + Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(head_dim, BLOCK_SIZE_N)) + Each program computes a tile of the value gradient matrix. + """ + batch_head_id = tl.program_id(0) + tile_m = tl.program_id(1) + tile_n = tl.program_id(2) + + batch_id = batch_head_id // num_heads + head_id = batch_head_id % num_heads + + row_start = tile_m * BLOCK_SIZE_M + col_start = tile_n * BLOCK_SIZE_N + + row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M) + col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k_start in range(0, seq_len, BLOCK_SIZE_K): + k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K) + k_mask = k_offsets < seq_len + + attn_ptrs = ( + Attn_ptr + + batch_id * attn_batch_stride + + head_id * attn_head_stride + + k_offsets[:, None] * attn_seq_stride + + row_offsets[None, :] * attn_seq2_stride + ) + attn_mask = k_mask[:, None] & (row_offsets[None, :] < seq_len) + attn_chunk = tl.load(attn_ptrs, mask=attn_mask, other=0.0) + + grad_out_ptrs = ( + grad_output_ptr + + batch_id * grad_out_batch_stride + + head_id * grad_out_head_stride + + k_offsets[:, None] * grad_out_seq_stride + + col_offsets[None, :] * grad_out_dim_stride + ) + grad_out_mask = k_mask[:, None] & (col_offsets[None, :] < head_dim) + grad_out_chunk = tl.load(grad_out_ptrs, mask=grad_out_mask, other=0.0) + + acc += tl.dot(tl.trans(attn_chunk), grad_out_chunk) + + grad_v_ptrs = ( + grad_V_ptr + + batch_id * grad_v_batch_stride + + head_id * grad_v_head_stride + + row_offsets[:, None] * grad_v_seq_stride + + col_offsets[None, :] * grad_v_dim_stride + ) + valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < head_dim) + tl.store(grad_v_ptrs, acc, mask=valid_mask) + + +@triton.jit +def _fused_neighborhood_attention_grad_attn_kernel( + grad_output_ptr, + V_ptr, + grad_attn_ptr, + grad_out_batch_stride, + grad_out_head_stride, + grad_out_seq_stride, + grad_out_dim_stride, + v_batch_stride, + v_head_stride, + v_seq_stride, + v_dim_stride, + grad_attn_batch_stride, + grad_attn_head_stride, + grad_attn_seq_stride, + grad_attn_seq2_stride, + batch_size: tl.constexpr, + num_heads: tl.constexpr, + seq_len: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + num_stages: tl.constexpr, + num_warps: tl.constexpr, +): + """ + Compute gradient with respect to attention weights: grad_attn = grad_output @ V^T. + + This kernel computes the gradient of the loss with respect to the attention + weights by multiplying the gradient of the output with the transpose of the + value tensor. 
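+    In matrix form: dL/dA = (dL/dO) @ V^T.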
This gradient will later be passed through the softmax backward + pass to compute gradients for the attention scores. + + Args: + grad_output_ptr: Pointer to gradient of output [batch_size, num_heads, seq_len, head_dim] + V_ptr: Pointer to value tensor [batch_size, num_heads, seq_len, head_dim] + grad_attn_ptr: Pointer to output gradient tensor [batch_size, num_heads, seq_len, seq_len] + grad_out_*_stride: Strides for gradient output tensor + v_*_stride: Strides for value tensor + grad_attn_*_stride: Strides for gradient attention tensor + batch_size: Number of batches + num_heads: Number of attention heads + seq_len: Sequence length + head_dim: Dimension of each attention head + BLOCK_SIZE_M: Block size for sequence dimension (rows) + BLOCK_SIZE_N: Block size for sequence dimension (cols) + BLOCK_SIZE_K: Block size for head dimension (reduction) + num_stages: Number of pipeline stages + num_warps: Number of warps + + Grid: (batch_size * num_heads, cdiv(seq_len, BLOCK_SIZE_M), cdiv(seq_len, BLOCK_SIZE_N)) + Each program computes a tile of the attention gradient matrix. + """ + batch_head_id = tl.program_id(0) + tile_m = tl.program_id(1) + tile_n = tl.program_id(2) + + batch_id = batch_head_id // num_heads + head_id = batch_head_id % num_heads + + row_start = tile_m * BLOCK_SIZE_M + col_start = tile_n * BLOCK_SIZE_N + + row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M) + col_offsets = col_start + tl.arange(0, BLOCK_SIZE_N) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k_start in range(0, head_dim, BLOCK_SIZE_K): + k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K) + k_mask = k_offsets < head_dim + + grad_out_ptrs = ( + grad_output_ptr + + batch_id * grad_out_batch_stride + + head_id * grad_out_head_stride + + row_offsets[:, None] * grad_out_seq_stride + + k_offsets[None, :] * grad_out_dim_stride + ) + grad_out_mask = (row_offsets[:, None] < seq_len) & k_mask[None, :] + grad_out_chunk = tl.load(grad_out_ptrs, mask=grad_out_mask, other=0.0) + + v_ptrs = ( + V_ptr + + batch_id * v_batch_stride + + head_id * v_head_stride + + col_offsets[None, :] * v_seq_stride + + k_offsets[:, None] * v_dim_stride + ) + v_mask = (col_offsets[None, :] < seq_len) & k_mask[:, None] + v_chunk = tl.load(v_ptrs, mask=v_mask, other=0.0) + + acc += tl.dot(grad_out_chunk, v_chunk) + + grad_attn_ptrs = ( + grad_attn_ptr + + batch_id * grad_attn_batch_stride + + head_id * grad_attn_head_stride + + row_offsets[:, None] * grad_attn_seq_stride + + col_offsets[None, :] * grad_attn_seq2_stride + ) + valid_mask = (row_offsets[:, None] < seq_len) & (col_offsets[None, :] < seq_len) + tl.store(grad_attn_ptrs, acc, mask=valid_mask) + + +def fused_neighborhood_attention_forward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kernel_size: int = 7, + dilation: int = 1, + scale: float = None, + return_lse: bool = False, +) -> tuple: + """ + Fused neighborhood attention forward pass. 
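+
+    The pass runs in three stages: build the (seq_len, seq_len) neighborhood
+    mask, compute the masked and scaled Q @ K^T scores, then softmax the rows
+    and multiply by V. A minimal call sketch (shapes and device are
+    illustrative):
+
+        q = torch.randn(2, 8, 128, 64, device="cuda")
+        k, v = torch.randn_like(q), torch.randn_like(q)
+        out, attn, params = fused_neighborhood_attention_forward(q, k, v, kernel_size=7)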
+ + Args: + query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim] + key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim] + value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim] + kernel_size: Size of the neighborhood window + dilation: Dilation factor for the neighborhood + scale: Scaling factor for attention scores (default: rsqrt(head_dim)) + return_lse: Whether to return log-sum-exp values + + Returns: + Tuple of (output tensor, softmax parameters for backward) + """ + batch_size, num_heads, seq_len, head_dim = query.shape + + if scale is None: + scale = 1.0 / math.sqrt(head_dim) + + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + output = torch.empty_like(query) + qk_scores = torch.empty(batch_size, num_heads, seq_len, seq_len, device=query.device, dtype=query.dtype) + + mask = torch.zeros(seq_len, seq_len, device=query.device, dtype=torch.float32) + + BLOCK_SIZE, num_warps = calculate_settings(seq_len) + BLOCK_SIZE_M = min(64, triton.next_power_of_2(seq_len)) + BLOCK_SIZE_N = min(64, triton.next_power_of_2(seq_len)) + BLOCK_SIZE_K = max(16, triton.next_power_of_2(head_dim)) + + num_stages = 4 if seq_len >= 512 else 2 + + grid_mask = (seq_len,) + _neighborhood_mask_kernel[grid_mask]( + mask, + seq_len, + kernel_size, + dilation, + BLOCK_SIZE, + num_stages, + num_warps, + ) + + grid_qk = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(seq_len, BLOCK_SIZE_N)) + _fused_neighborhood_attention_qk_kernel[grid_qk]( + query, + key, + qk_scores, + mask, + query.stride(0), + query.stride(1), + query.stride(2), + query.stride(3), + key.stride(0), + key.stride(1), + key.stride(2), + key.stride(3), + qk_scores.stride(0), + qk_scores.stride(1), + qk_scores.stride(2), + qk_scores.stride(3), + batch_size, + num_heads, + seq_len, + head_dim, + scale, + kernel_size, + dilation, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + num_stages, + num_warps, + ) + + qk_reshaped = qk_scores.view(batch_size * num_heads * seq_len, seq_len) + attn_reshaped, BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch = _softmax_forward(qk_reshaped) + attn_weights = attn_reshaped.view(batch_size, num_heads, seq_len, seq_len) + + grid_av = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N)) + _fused_neighborhood_attention_av_kernel[grid_av]( + attn_weights, + value, + output, + attn_weights.stride(0), + attn_weights.stride(1), + attn_weights.stride(2), + attn_weights.stride(3), + value.stride(0), + value.stride(1), + value.stride(2), + value.stride(3), + output.stride(0), + output.stride(1), + output.stride(2), + output.stride(3), + batch_size, + num_heads, + seq_len, + head_dim, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + num_stages, + num_warps, + ) + + if return_lse: + raise NotImplementedError("return_lse=True is not supported yet.") + + softmax_params = (BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch) + return output, attn_weights, softmax_params + + +class LigerFusedNeighborhoodAttentionFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, query, key, value, kernel_size=7, dilation=1, scale=None): + output, attn_weights, softmax_params = fused_neighborhood_attention_forward( + query, key, value, kernel_size, dilation, scale + ) + ctx.save_for_backward(query, key, value, attn_weights) + ctx.kernel_size = kernel_size + ctx.dilation = dilation + ctx.scale = scale + ctx.softmax_params = softmax_params + 
return output + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output): + query, key, value, attn_weights = ctx.saved_tensors + BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch = ctx.softmax_params + + batch_size, num_heads, seq_len, head_dim = query.shape + scale = ctx.scale if ctx.scale is not None else 1.0 / math.sqrt(head_dim) + + grad_query = torch.zeros_like(query) + grad_key = torch.zeros_like(key) + grad_value = torch.zeros_like(value) + grad_attn_weights = torch.zeros_like(attn_weights) + + BLOCK_SIZE_M = min(64, triton.next_power_of_2(seq_len)) + BLOCK_SIZE_N = min(64, triton.next_power_of_2(seq_len)) + BLOCK_SIZE_K = min(64, triton.next_power_of_2(head_dim)) + num_stages = 4 if seq_len >= 512 else 2 + _, num_warps = calculate_settings(seq_len) + + grid_grad_attn = ( + batch_size * num_heads, + triton.cdiv(seq_len, BLOCK_SIZE_M), + triton.cdiv(seq_len, BLOCK_SIZE_N), + ) + _fused_neighborhood_attention_grad_attn_kernel[grid_grad_attn]( + grad_output, + value, + grad_attn_weights, + grad_output.stride(0), + grad_output.stride(1), + grad_output.stride(2), + grad_output.stride(3), + value.stride(0), + value.stride(1), + value.stride(2), + value.stride(3), + grad_attn_weights.stride(0), + grad_attn_weights.stride(1), + grad_attn_weights.stride(2), + grad_attn_weights.stride(3), + batch_size, + num_heads, + seq_len, + head_dim, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + num_stages, + num_warps, + ) + + grad_attn_reshaped = grad_attn_weights.view(batch_size * num_heads * seq_len, seq_len) + attn_reshaped = attn_weights.view(batch_size * num_heads * seq_len, seq_len) + + grad_qk_reshaped = _softmax_backward( + grad_attn_reshaped, attn_reshaped, BLOCK_SIZE_softmax, num_warps_softmax, multi_block_launch + ) + grad_qk_scores = grad_qk_reshaped.view(batch_size, num_heads, seq_len, seq_len) + + grid_grad_q = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N)) + _fused_neighborhood_attention_grad_qk_kernel[grid_grad_q]( + grad_qk_scores, + key, + grad_query, + grad_qk_scores.stride(0), + grad_qk_scores.stride(1), + grad_qk_scores.stride(2), + grad_qk_scores.stride(3), + key.stride(0), + key.stride(1), + key.stride(2), + key.stride(3), + grad_query.stride(0), + grad_query.stride(1), + grad_query.stride(2), + grad_query.stride(3), + batch_size, + num_heads, + seq_len, + head_dim, + scale, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + num_stages, + num_warps, + ) + + grid_grad_k = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N)) + _fused_neighborhood_attention_grad_k_kernel[grid_grad_k]( + grad_qk_scores, + query, + grad_key, + grad_qk_scores.stride(0), + grad_qk_scores.stride(1), + grad_qk_scores.stride(2), + grad_qk_scores.stride(3), + query.stride(0), + query.stride(1), + query.stride(2), + query.stride(3), + grad_key.stride(0), + grad_key.stride(1), + grad_key.stride(2), + grad_key.stride(3), + batch_size, + num_heads, + seq_len, + head_dim, + scale, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + num_stages, + num_warps, + ) + + grid_grad_v = (batch_size * num_heads, triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(head_dim, BLOCK_SIZE_N)) + _fused_neighborhood_attention_grad_v_kernel[grid_grad_v]( + attn_weights, + grad_output, + grad_value, + attn_weights.stride(0), + attn_weights.stride(1), + attn_weights.stride(2), + attn_weights.stride(3), + grad_output.stride(0), + grad_output.stride(1), + grad_output.stride(2), + grad_output.stride(3), + 
grad_value.stride(0), + grad_value.stride(1), + grad_value.stride(2), + grad_value.stride(3), + batch_size, + num_heads, + seq_len, + head_dim, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + num_stages, + num_warps, + ) + + return grad_query, grad_key, grad_value, None, None, None diff --git a/src/liger_kernel/ops/geglu.py b/src/liger_kernel/ops/geglu.py new file mode 100755 index 0000000000000000000000000000000000000000..6aa99c405797018d49974e4e357273d90bc413d9 --- /dev/null +++ b/src/liger_kernel/ops/geglu.py @@ -0,0 +1,143 @@ +import operator + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import calculate_settings +from liger_kernel.ops.utils import compare_version +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.utils import is_npu_available + +if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available(): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import tanh + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import tanh +else: + from triton.language.math import tanh + + +@triton.jit +def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr): + program_id = tl.program_id(0).to(tl.int64) + + # locate start index + a += program_id * stride + b += program_id * stride + c += program_id * stride + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32) + b_row = tl.load(b + col_offsets, mask=mask, other=0) + + # tanh approximation form of GELU is computed with: + # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3))) + sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi) + a_cubed = a_row * a_row * a_row + tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed) + tanh_result = tanh(tanh_arg) + geglu_a = 0.5 * a_row * (1 + tanh_result) + c_row = geglu_a.cast(b_row.dtype) * b_row + tl.store(c + col_offsets, c_row, mask=mask) + + +@triton.jit +def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr): + program_id = tl.program_id(0).to(tl.int64) + + # locate start index + dc += program_id * stride + a += program_id * stride + b += program_id * stride + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + dc_row = tl.load(dc + col_offsets, mask=mask, other=0) + a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32) + b_row = tl.load(b + col_offsets, mask=mask, other=0) + + # recomputation to save memory + sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi) + a_cubed = a_row * a_row * a_row + tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed) + tanh_result = tanh(tanh_arg) + geglu_a = 0.5 * a_row * (1 + tanh_result) + geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32) + + db_row = dc_row.cast(tl.float32) * geglu_a + + # Gradient w.r.t. 
a can be computed with: + # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2))) + # where z = sqrt(2/pi) * (a + 0.044715 * a^3) + term1 = 0.5 * (1 + tanh_result) + tanh_sq = tanh_result * tanh_result + term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row)) + da_row = dc_row * b_row * (term1 + term2) + + tl.store(a + col_offsets, da_row, mask=mask) + tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask) + + +def geglu_forward(a, b): + ori_shape = a.shape + + n_cols = ori_shape[-1] + a = a.view(-1, n_cols) + b = b.view(-1, n_cols) + c = torch.empty_like(a) + n_rows = a.shape[0] + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + _geglu_tanh_forward_kernel[(n_rows,)]( + a, + b, + c, + c.stride(-2), + n_cols=n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return a, b, c.view(*ori_shape) + + +def geglu_backward(a, b, dc): + ori_shape = dc.shape + n_cols = ori_shape[-1] + dc = dc.view(-1, n_cols) + n_rows = dc.shape[0] + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + _geglu_tanh_backward_kernel[(n_rows,)]( + dc, + a, + b, + dc.stride(-2), + n_cols=n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + + return a.view(*ori_shape), b.view(*ori_shape) + + +class LigerGELUMulFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, a, b): + a, b, c = geglu_forward(a, b) + ctx.save_for_backward(a, b) + return c + + @staticmethod + @ensure_contiguous + def backward(ctx, dc): + a, b = ctx.saved_tensors + a, b = geglu_backward(a, b, dc) + return a, b diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..865fc337f77518562e0cac3587f9b8fb84c679e7 --- /dev/null +++ b/src/liger_kernel/ops/group_norm.py @@ -0,0 +1,311 @@ +import operator + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import compare_version +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.utils import infer_device +from liger_kernel.utils import is_npu_available + +if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available(): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import rsqrt + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import rsqrt +else: + from triton.language.math import rsqrt + +if infer_device() == "npu": + MAX_FUSED_SIZE = 16384 # 8192 +else: + MAX_FUSED_SIZE = 65536 + + +@triton.jit +def _group_norm_forward_kernel( + Y_ptr, # pointer to output, shape (n_rows, n_groups, hidden_size) + Y_row_stride, # stride of each row in output + Y_col_stride, # stride of each column in output + X_ptr, # pointer to input, shape (n_rows, n_groups, hidden_size) + X_row_stride, # stride of each row in input + X_col_stride, # stride of each column in input + Mean_ptr, # pointer to mean, shape (n_rows, n_groups) + Mean_row_stride, # stride of each row in mean + Mean_col_stride, # stride of each column in mean + RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups) + RSTD_row_stride, # stride of each row in rstd + RSTD_col_stride, # stride of each column in rstd + W_ptr, # pointer to W + B_ptr, # pointer to B + hidden_size, # hidden size of X + channels_per_group, # the number of channels per group + eps, + BLOCK_SIZE: tl.constexpr, +): + """ + References: + 
https://nn.labml.ai/normalization/group_norm/index.html + """ + batch_idx = tl.program_id(0) + group_idx = tl.program_id(1) + + X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride + Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride + + block_range = tl.arange(0, BLOCK_SIZE) + + # Compute mean and variance in a single pass using running sums (E[X] and E[X**2]) + s = 0.0 + squared_sum = 0.0 + for i in tl.range(0, hidden_size, BLOCK_SIZE): + hidden_size_offsets = i + block_range + mask = hidden_size_offsets < hidden_size + X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0) + s += tl.sum(X) + # X**2 + squared_sum += tl.sum(X * X) + + m = s / hidden_size + + # variance = E[X**2] - E[X]**2 + variance = (squared_sum / hidden_size) - (m * m) + + # 1/std + rstd = rsqrt(variance + eps) + + # Normalize: flat loop over the full hidden_size (not per-channel). + # This avoids the nested channel × per_channel_hidden loop where + # BLOCK_SIZE >> hidden_size_per_channel causes massive padding waste. + hidden_size_per_channel = hidden_size // channels_per_group + for i in tl.range(0, hidden_size, BLOCK_SIZE): + hidden_size_offsets = i + block_range + mask = hidden_size_offsets < hidden_size + X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m) + # Determine which channel each element belongs to, then load W/B + local_channel = hidden_size_offsets // hidden_size_per_channel + global_channel = group_idx * channels_per_group + local_channel + W = tl.load(W_ptr + global_channel, mask=mask) + B = tl.load(B_ptr + global_channel, mask=mask) + Y = (X - m) * rstd * W + B + tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask) + + tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m) + tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd) + + +@triton.jit +def _group_norm_backward_kernel( + X_ptr, # pointer to input, shape (n_rows, n_channels, hidden_size) + X_row_stride, # stride of each row in input + X_col_stride, # stride of each column in input + W_ptr, # pointer to weights, shape (n_channels) + Mean_ptr, # pointer to mean, shape (n_rows, n_groups) + Mean_ptr_row_stride, # stride of each row in mean + Mean_ptr_col_stride, # stride of each column in mean + RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups) + DX_ptr, # pointer to input grad, shape (n_rows, n_groups, hidden_size) + DW_ptr, # pointer to weights grad, shape (n_channels) + DB_ptr, # pointer to bias grad, shape (n_channels) + UPSTREAM_ptr, # pointer to output grad, shape (n_rows, n_channels, hidden_size) + hidden_size: tl.constexpr, # hidden size + channels_per_group: tl.constexpr, # number of channels per group + BLOCK_SIZE: tl.constexpr, + dtype: tl.constexpr, +): + """ + References: + https://nn.labml.ai/normalization/group_norm/index.html + https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md + + The backprop equations are the same for group_norm and layer_norm; the only + difference is that we load the Mean and Rstd corresponding to the group we're + computing gradients for, and those statistics are computed over all channels + in the group, so the total number of elements reduced over is + channels_per_group * hidden_size. + + We also need to load the Weights corresponding to the current channel to compute the gradients.
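+
+    Concretely, writing x_hat = (x - mean) * rstd and wdy = w * dy, the loops
+    below accumulate
+
+        c1 = sum(x_hat * wdy) / N
+        c2 = sum(wdy) / N
+
+    and then compute dx = (wdy - (x_hat * c1 + c2)) * rstd, where
+    N = channels_per_group * hidden_size, i.e. both sums run over every
+    element of the group.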
+ """ + batch_idx = tl.program_id(0) + group_idx = tl.program_id(1) + + # Move the pointers to the correct batch + X_ptr += batch_idx * X_row_stride + DX_ptr += batch_idx * X_row_stride + UPSTREAM_ptr += batch_idx * X_row_stride + + # Mean and rstd are the same shape so have the same strides + mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) + rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) + + c1 = 0.0 + c2 = 0.0 + block_range = tl.arange(0, BLOCK_SIZE) + + # We need to compute the sum terms of the backprop equations across all channels in the group + for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group): + dW = 0.0 + dB = 0.0 + # Move the pointers to the correct channel + W = tl.load(W_ptr + channel_idx) + for i in tl.range(0, hidden_size, BLOCK_SIZE): + hidden_size_offsets = i + block_range + mask = hidden_size_offsets < hidden_size + X = tl.load( + X_ptr + channel_idx * X_col_stride + hidden_size_offsets, + mask=mask, + other=0.0, + ) + UPSTREAM_grad = tl.load( + UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets, + mask=mask, + other=0.0, + ) + + x_hat = (X - mean) * rstd + dW += tl.sum(UPSTREAM_grad * x_hat) + dB += tl.sum(UPSTREAM_grad) + + wdy = W * UPSTREAM_grad + c1 += tl.sum(x_hat * wdy) + c2 += tl.sum(wdy) + + # Need to ensure additions to the same channel are atomic + tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype)) + tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype)) + + N = hidden_size * channels_per_group + c1 = c1 / N + c2 = c2 / N + + for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group): + # Move the pointers to the correct channel + W = tl.load(W_ptr + channel_idx) + for i in range(0, hidden_size, BLOCK_SIZE): + hidden_size_offsets = i + block_range + mask = hidden_size_offsets < hidden_size + X = tl.load( + X_ptr + channel_idx * X_col_stride + hidden_size_offsets, + mask=mask, + other=0.0, + ) + UPSTREAM_grad = tl.load( + UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets, + mask=mask, + other=0.0, + ) + + x_hat = (X - mean) * rstd + wdy = W * UPSTREAM_grad + dx = (wdy - (x_hat * c1 + c2)) * rstd + tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask) + + +def group_norm_forward(X, num_channels, num_groups, W, B, eps): + shape = X.shape + batch_size = shape[0] + channels_per_group = num_channels // num_groups + # Reshape X so that the mean and std are computed across the groups + X = X.view(batch_size, num_groups, -1).contiguous() + hidden_size = X.shape[-1] + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size)) + Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device) + Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) + RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) + + _group_norm_forward_kernel[(batch_size, num_groups)]( + Y, + Y.stride(0), + Y.stride(1), + X, + X.stride(0), + X.stride(1), + Mean, + Mean.stride(0), + Mean.stride(1), + RSTD, + RSTD.stride(0), + RSTD.stride(1), + W, + B, + hidden_size, + channels_per_group, + eps, + BLOCK_SIZE=BLOCK_SIZE, + ) + # Return tensors in the original shape + return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE + + +def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): + shape = dY.shape + batch_size = shape[0] + hidden_size = dY.shape[-1] + channels_per_group = 
num_channels // num_groups + dY = dY.view(batch_size, num_groups, -1) + DX = torch.empty( + (batch_size, num_groups, hidden_size * channels_per_group), + dtype=X.dtype, + device=X.device, + ) + DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device) + DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device) + triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16 + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size)) + _group_norm_backward_kernel[(batch_size, num_groups)]( + X, + X.stride(0), + X.stride(1), + W, + Mean, + Mean.stride(0), + Mean.stride(1), + RSTD, + DX, + DW, + DB, + dY, + hidden_size, + channels_per_group, + BLOCK_SIZE=BLOCK_SIZE, + dtype=triton_dtype, + ) + + # Return tensors in the original shape + return DX.view(*shape), DW, DB + + +class LigerGroupNormFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward( + ctx, + X, + affine_scaling_weight, + affine_shifting_bias, + num_channels, + num_groups, + eps, + ): + Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward( + X, + num_channels, + num_groups, + affine_scaling_weight, + affine_shifting_bias, + eps, + ) + ctx.num_channels = num_channels + ctx.num_groups = num_groups + ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD) + return Y + + @staticmethod + @ensure_contiguous + def backward(ctx, dY): + X, W, B, Mean, RSTD = ctx.saved_tensors + DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups) + return DX, DW, DB, None, None, None diff --git a/src/liger_kernel/ops/grpo_loss.py b/src/liger_kernel/ops/grpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..bcd1262bc84f0966710e381ca9c434a0b4974f42 --- /dev/null +++ b/src/liger_kernel/ops/grpo_loss.py @@ -0,0 +1,930 @@ +import torch +import triton +import triton.language as tl + +# Loss type constants for Triton constexpr branching +# GRPO/DAPO/BNPO/DR_GRPO all use the same per-token loss computation (standard PPO clipping) +_LOSS_TYPE_GRPO: tl.constexpr = tl.constexpr(0) +_LOSS_TYPE_CISPO: tl.constexpr = tl.constexpr(1) +_LOSS_TYPE_SAPO: tl.constexpr = tl.constexpr(2) + +_str_to_loss_type = { + "grpo": _LOSS_TYPE_GRPO.value, + "dapo": _LOSS_TYPE_GRPO.value, + "bnpo": _LOSS_TYPE_GRPO.value, + "dr_grpo": _LOSS_TYPE_GRPO.value, + "luspo": _LOSS_TYPE_GRPO.value, + "cispo": _LOSS_TYPE_CISPO.value, + "sapo": _LOSS_TYPE_SAPO.value, +} + + +@triton.jit +def _selective_log_softmax_kernel( + LOGITS, + INPUT_IDS, + LOG_P, + MASK, + TEMPERATURE, + stride_input_ids_b, + L: tl.constexpr, + N: tl.constexpr, + BLOCK_N: tl.constexpr = 4096, +): + off_b = tl.program_id(0).cast(tl.int64) + off_l = tl.program_id(1).cast(tl.int64) + + LOGITS += off_b * (L + 1) * N + off_l * N + INPUT_IDS += off_b * stride_input_ids_b + off_l + LOG_P += off_b * L + off_l + + if MASK is not None: + MASK += off_b * stride_input_ids_b + off_l + not_skip = tl.load(MASK) + if not_skip == 0: + return + + m_i = float("-inf") + l_i = 0.0 + for start in range(0, N, BLOCK_N): + cols = start + tl.arange(0, BLOCK_N) + logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE + new_m_i = tl.maximum(m_i, tl.max(logits)) + alpha = tl.exp(m_i - new_m_i) + l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i)) + m_i = new_m_i + lse = m_i + tl.log(l_i) + + ids = tl.load(INPUT_IDS) + x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE + logp = x - lse + tl.store(LOG_P, logp) + + +# compue old_logp and ref_logp, it reduce 
10G peak Memory. it does not requires grad +@torch.no_grad +def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None): + assert logits.is_contiguous() + B, L_ADD_1, N = logits.shape + L = L_ADD_1 - 1 + input_ids = input_ids[:, -L:] + if mask is not None: + mask = mask[:, -L:] + log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device) + kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1} + _selective_log_softmax_kernel[(B, L)]( + logits, input_ids, log_p, mask, temperature, input_ids.stride(0), L, N, **kwargs + ) + return log_p + + +# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw) +# for BLOCK_N in [2048, 4096, 8192] +# for ns in [1, 2, 4] +# for nw in [1, 2, 4, 8, 16]], +# key=['N']) +@triton.jit +def _grpo_loss_fwd_kernel( + LOGITS, + OLD_LOGP, + REF_LOGP, + INPUT_IDS, + COMPLETION_MASK, + ADVANTAGES, + VLLM_IS_RATIO, + VLLM_IS_RATIO_STRIDE, + LOSS, + LSE, + KL, + IS_CLIPPED, + TEMPERATURE, + BETA: tl.constexpr, + EPS_LOW, + EPS_HIGH, + LOSS_TYPE: tl.constexpr, + SAPO_TEMP_POS, + SAPO_TEMP_NEG, + DELTA, + USE_BIAS_CORRECTION_KL: tl.constexpr, + L: tl.constexpr, + N: tl.constexpr, + BLOCK_N: tl.constexpr = 4096, +): + off_b = tl.program_id(0).cast(tl.int64) + off_l = tl.program_id(1).cast(tl.int64) + + if COMPLETION_MASK is not None: + COMPLETION_MASK += off_b * L + off_l + not_skip = tl.load(COMPLETION_MASK) + if not_skip == 0: + return + + LOGITS += off_b * (L + 1) * N + off_l * N + INPUT_IDS += off_b * L + off_l + ADVANTAGES += off_b + LOSS += off_b * L + off_l + LSE += off_b * L + off_l + IS_CLIPPED += off_b * L + off_l + + m_i = float("-inf") + l_i = 0.0 + for start in range(0, N, BLOCK_N): + cols = start + tl.arange(0, BLOCK_N) + logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE + new_m_i = tl.maximum(m_i, tl.max(logits)) + alpha = tl.exp(m_i - new_m_i) + l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i)) + m_i = new_m_i + lse = m_i + tl.log(l_i) + + idx = tl.load(INPUT_IDS) + x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE + logp = x - lse + if OLD_LOGP is None: + old_logp = logp + else: + OLD_LOGP += off_b * L + off_l + old_logp = tl.load(OLD_LOGP).to(tl.float32) + coef_1 = tl.exp(logp - old_logp) + advantage = tl.load(ADVANTAGES).to(tl.float32) + + # Branch based on loss type + if LOSS_TYPE == 0: # GRPO/DAPO/BNPO/DR_GRPO: standard PPO clipping + coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH) + is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0) + is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0) + is_clipped = is_low_clipped | is_high_clipped + # Apply delta (two-sided clipping from INTELLECT-2) to coef_1 + if DELTA != 0.0: + coef_1 = tl.minimum(coef_1, DELTA) + per_token_loss1 = coef_1 * advantage + per_token_loss2 = coef_2 * advantage + per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2) + + elif LOSS_TYPE == 1: # CISPO: upper-bound only clipping, detached, multiply by logp + # Reference: MiniMax-M1 technical report + # https://github.com/huggingface/trl/blob/035c3ff151b953ca72cdfe0ee966bc1469a26fde/trl/trainer/grpo_trainer.py#L2030 + coef_2 = tl.minimum(coef_1, EPS_HIGH) # upper-bound only (EPS_HIGH is the raw bound for CISPO) + per_token_loss = -coef_2 * advantage * logp # includes logp term + is_clipped = (coef_1 > EPS_HIGH) & (advantage > 0) + + elif LOSS_TYPE == 2: # SAPO: soft adaptive policy optimization with sigmoid gating + # Reference: 
https://huggingface.co/papers/2511.20347 + # Formula: sigmoid(τ * (ρ - 1)) * 4 / τ + temperature = tl.where(advantage > 0, SAPO_TEMP_POS, SAPO_TEMP_NEG) + sigmoid_input = temperature * (coef_1 - 1.0) + sapo_coef = tl.sigmoid(sigmoid_input) * 4.0 / temperature + per_token_loss = -sapo_coef * advantage + is_clipped = 0.0 # SAPO has no clipping concept + + # Apply vLLM importance sampling correction BEFORE adding KL penalty + if VLLM_IS_RATIO is not None: + # Use modulo to support both (B, L) per-token and (B, 1) per-sequence shapes + vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to( + tl.float32 + ) + per_token_loss = per_token_loss * vllm_is_ratio + + if BETA != 0.0: + REF_LOGP += off_b * L + off_l + KL += off_b * L + off_l + ref_logp = tl.load(REF_LOGP).to(tl.float32) + kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1 + if USE_BIAS_CORRECTION_KL: + # Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= coef_1 + kl = kl * tl.exp(logp - old_logp) + per_token_loss += BETA * kl + tl.store(KL, kl) + + tl.store(LOSS, per_token_loss) + tl.store(LSE, lse) + tl.store(IS_CLIPPED, is_clipped) + + +# Sequence-level forward kernel: uses pre-computed coef_1 per sequence +@triton.jit +def _grpo_loss_fwd_kernel_seq( + LOGITS, + OLD_LOGP, + REF_LOGP, + INPUT_IDS, + COMPLETION_MASK, + ADVANTAGES, + COEF_1, # Pre-computed sequence-level importance weight (B,) + COEF_2, # Pre-computed clipped coef (B,) + IS_CLIPPED_SEQ, # Pre-computed clipping indicator (B,) + VLLM_IS_RATIO, # vLLM importance sampling ratio (B, L) or (B, 1) or None + VLLM_IS_RATIO_STRIDE, # stride for VLLM_IS_RATIO (L for per-token, 1 for per-sequence) + LOSS, + LSE, + KL, + IS_CLIPPED, + TEMPERATURE, + BETA: tl.constexpr, + USE_BIAS_CORRECTION_KL: tl.constexpr, + L: tl.constexpr, + N: tl.constexpr, + BLOCK_N: tl.constexpr = 4096, +): + off_b = tl.program_id(0).cast(tl.int64) + off_l = tl.program_id(1).cast(tl.int64) + + if COMPLETION_MASK is not None: + COMPLETION_MASK += off_b * L + off_l + not_skip = tl.load(COMPLETION_MASK) + if not_skip == 0: + return + + LOGITS += off_b * (L + 1) * N + off_l * N + INPUT_IDS += off_b * L + off_l + ADVANTAGES += off_b + COEF_1 += off_b + COEF_2 += off_b + IS_CLIPPED_SEQ += off_b + LOSS += off_b * L + off_l + LSE += off_b * L + off_l + IS_CLIPPED += off_b * L + off_l + + # Compute log softmax + m_i = float("-inf") + l_i = 0.0 + for start in range(0, N, BLOCK_N): + cols = start + tl.arange(0, BLOCK_N) + logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE + new_m_i = tl.maximum(m_i, tl.max(logits)) + alpha = tl.exp(m_i - new_m_i) + l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i)) + m_i = new_m_i + lse = m_i + tl.log(l_i) + + idx = tl.load(INPUT_IDS) + x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE + logp = x - lse + + # Load pre-computed sequence-level coefficients + coef_1 = tl.load(COEF_1).to(tl.float32) + coef_2 = tl.load(COEF_2).to(tl.float32) + is_clipped_seq = tl.load(IS_CLIPPED_SEQ) + + advantage = tl.load(ADVANTAGES).to(tl.float32) + per_token_loss1 = coef_1 * advantage + per_token_loss2 = coef_2 * advantage + per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2) + + # Apply vLLM importance sampling correction BEFORE adding KL + if VLLM_IS_RATIO is not None: + vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to( + tl.float32 + ) + per_token_loss = per_token_loss * vllm_is_ratio + + if BETA != 0.0: + REF_LOGP += off_b * L + 
off_l + KL += off_b * L + off_l + ref_logp = tl.load(REF_LOGP).to(tl.float32) + kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1 + if USE_BIAS_CORRECTION_KL: + # Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= token-level coef_1 + if OLD_LOGP is None: + old_logp = logp + else: + old_logp = tl.load(OLD_LOGP + off_b * L + off_l).to(tl.float32) + kl = kl * tl.exp(logp - old_logp) + per_token_loss += BETA * kl + tl.store(KL, kl) + + tl.store(LOSS, per_token_loss) + tl.store(LSE, lse) + tl.store(IS_CLIPPED, is_clipped_seq) # Same for all tokens in sequence + + +# Sequence-level backward kernel +@triton.jit +def _grpo_loss_bwd_kernel_seq( + DLOSS, + DLOSS_SUM, + DLOGITS, + LOGITS, + OLD_LOGP, + REF_LOGP, + INPUT_IDS, + ADVANTAGES, + COMPLETION_MASK, + LSE, + COEF_1, # Pre-computed sequence-level importance weight (B,) + SEQ_LEN, # Number of valid tokens per sequence (B,) + TEMPERATURE, + BETA: tl.constexpr, + USE_BIAS_CORRECTION_KL: tl.constexpr, + EPS_LOW, + EPS_HIGH, + DELTA, + loss_stride0, + loss_stride1, + L: tl.constexpr, + N: tl.constexpr, + BLOCK_N: tl.constexpr = 4096, +): + off_b = tl.program_id(0).cast(tl.int64) + off_l = tl.program_id(1).cast(tl.int64) + + DLOGITS += off_b * (L + 1) * N + off_l * N + if COMPLETION_MASK is not None: + COMPLETION_MASK += off_b * L + off_l + not_skip = tl.load(COMPLETION_MASK) + if not_skip == 0: + for start in range(0, N, BLOCK_N): + cols = tl.arange(0, BLOCK_N) + start + tl.store(DLOGITS + cols, 0.0, mask=cols < N) + return + + LOGITS += off_b * (L + 1) * N + off_l * N + DLOSS += off_b * loss_stride0 + off_l * loss_stride1 + DLOSS_SUM += off_b + INPUT_IDS += off_b * L + off_l + ADVANTAGES += off_b + LSE += off_b * L + off_l + COEF_1 += off_b + SEQ_LEN += off_b + + dloss = tl.load(DLOSS).to(tl.float32) + dloss_sum = tl.load(DLOSS_SUM).to(tl.float32) + lse = tl.load(LSE).to(tl.float32) + coef_1 = tl.load(COEF_1).to(tl.float32) + seq_len = tl.load(SEQ_LEN).to(tl.float32) + + idx = tl.load(INPUT_IDS) + x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE + logp = x - lse + + advantage = tl.load(ADVANTAGES).to(tl.float32) + coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH) + if DELTA != 0.0: + coef_1_for_loss = tl.minimum(coef_1, DELTA) + else: + coef_1_for_loss = coef_1 + per_token_loss1 = coef_1_for_loss * advantage + per_token_loss2 = coef_2 * advantage + is_unclipped = per_token_loss2 >= per_token_loss1 + + # For sequence-level: gradient flows through mean, so scale by coef_1/seq_len + # d(loss)/d(logp) = -advantage * coef_1 / seq_len (when unclipped and not delta-clamped) + dlogp = -coef_1 * advantage / seq_len * is_unclipped * dloss_sum + if DELTA != 0.0: + dlogp = dlogp * (coef_1 <= DELTA) + + if BETA != 0.0: + REF_LOGP += off_b * L + off_l + ref_logp = tl.load(REF_LOGP).to(tl.float32) + if USE_BIAS_CORRECTION_KL: + # d(kl * coef_1)/d(logp) = coef_1 * (logp - ref_logp), where coef_1 = exp(logp - old_logp) + if OLD_LOGP is None: + old_logp = logp + else: + old_logp = tl.load(OLD_LOGP + off_b * L + off_l).to(tl.float32) + token_coef_1 = tl.exp(logp - old_logp) + dlogp += BETA * token_coef_1 * (logp - ref_logp) * dloss + else: + dlogp += BETA * (1 - tl.exp(ref_logp - logp)) * dloss + + dlogp = dlogp / TEMPERATURE + tl.debug_barrier() + for start_n in tl.range(0, N, BLOCK_N): + cols = start_n + tl.arange(0, BLOCK_N) + logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE + probs = tl.exp(logits - lse) + dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp + tl.store(DLOGITS + cols, 
dlogits, mask=cols < N) + + +@triton.jit +def _grpo_loss_bwd_kernel( + DLOSS, + DLOGITS, + LOGITS, + OLD_LOGP, + REF_LOGP, + INPUT_IDS, + ADVANTAGES, + COMPLETION_MASK, + LSE, + VLLM_IS_RATIO, + VLLM_IS_RATIO_STRIDE, + TEMPERATURE, + BETA: tl.constexpr, + EPS_LOW, + EPS_HIGH, + LOSS_TYPE: tl.constexpr, + SAPO_TEMP_POS, + SAPO_TEMP_NEG, + DELTA, + USE_BIAS_CORRECTION_KL: tl.constexpr, + loss_stride0, + loss_stride1, + L: tl.constexpr, + N: tl.constexpr, + BLOCK_N: tl.constexpr = 4096, +): + off_b = tl.program_id(0).cast(tl.int64) + off_l = tl.program_id(1).cast(tl.int64) + + DLOGITS += off_b * (L + 1) * N + off_l * N + if COMPLETION_MASK is not None: + COMPLETION_MASK += off_b * L + off_l + not_skip = tl.load(COMPLETION_MASK) + if not_skip == 0: + for start in range(0, N, BLOCK_N): + cols = tl.arange(0, BLOCK_N) + start + tl.store(DLOGITS + cols, 0.0, mask=cols < N) + return + + LOGITS += off_b * (L + 1) * N + off_l * N + DLOSS += off_b * loss_stride0 + off_l * loss_stride1 + INPUT_IDS += off_b * L + off_l + ADVANTAGES += off_b + LSE += off_b * L + off_l + + dloss = tl.load(DLOSS).to(tl.float32) + lse = tl.load(LSE).to(tl.float32) + + idx = tl.load(INPUT_IDS) + x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE + logp = x - lse + if OLD_LOGP is None: + old_logp = logp + else: + OLD_LOGP += off_b * L + off_l + old_logp = tl.load(OLD_LOGP).to(tl.float32) + coef_1 = tl.exp(logp - old_logp) + advantage = tl.load(ADVANTAGES).to(tl.float32) + + # Branch based on loss type for gradient computation + if LOSS_TYPE == 0: # GRPO/DAPO/BNPO/DR_GRPO: standard PPO clipping + coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH) + if DELTA != 0.0: + coef_1_for_loss = tl.minimum(coef_1, DELTA) + else: + coef_1_for_loss = coef_1 + per_token_loss1 = coef_1_for_loss * advantage + per_token_loss2 = coef_2 * advantage + mask = per_token_loss2 >= per_token_loss1 + # Gradient uses original coef_1; zero when delta-clamped (constant → no gradient) + dlogp = -coef_1 * advantage * mask + if DELTA != 0.0: + dlogp = dlogp * (coef_1 <= DELTA) + + elif LOSS_TYPE == 1: # CISPO: coef_2 is DETACHED, so gradient only flows through logp + # loss = -coef_2 * advantage * logp, where coef_2 = clamp(coef_1, max=eps_high).detach() + # d(loss)/d(logp) = -coef_2 * advantage (coef_2 treated as constant due to detach) + coef_2 = tl.minimum(coef_1, EPS_HIGH) + dlogp = -coef_2 * advantage + + elif LOSS_TYPE == 2: # SAPO: gradient through sigmoid gating + # loss = -sapo_coef * advantage, where sapo_coef = sigmoid(τ*(ρ-1)) * 4/τ + # d(loss)/d(logp) = -advantage * d(sapo_coef)/d(coef_1) * d(coef_1)/d(logp) + # d(coef_1)/d(logp) = coef_1 (since coef_1 = exp(logp - old_logp)) + # d(sapo_coef)/d(coef_1) = d/d(coef_1)[sigmoid(τ*(coef_1-1)) * 4/τ] + # = τ * sigmoid' * 4/τ = 4 * sigmoid * (1 - sigmoid) + # (the τ factors cancel out in the derivative) + temperature = tl.where(advantage > 0, SAPO_TEMP_POS, SAPO_TEMP_NEG) + sigmoid_input = temperature * (coef_1 - 1.0) + sigmoid_val = tl.sigmoid(sigmoid_input) + d_sapo_d_coef1 = 4.0 * sigmoid_val * (1.0 - sigmoid_val) + dlogp = -advantage * d_sapo_d_coef1 * coef_1 + + # Apply vLLM IS ratio to PPO gradient (before KL gradient) + if VLLM_IS_RATIO is not None: + # Use modulo to support both (B, L) per-token and (B, 1) per-sequence shapes + vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to( + tl.float32 + ) + dlogp = dlogp * vllm_is_ratio + + if BETA != 0.0: + REF_LOGP += off_b * L + off_l + ref_logp = tl.load(REF_LOGP).to(tl.float32) + if 
USE_BIAS_CORRECTION_KL: + # d(kl * coef_1)/d(logp) = coef_1 * (logp - ref_logp), where coef_1 = exp(logp - old_logp) + dlogp += BETA * coef_1 * (logp - ref_logp) + else: + dlogp += BETA * (1 - tl.exp(ref_logp - logp)) + + dlogp = dlogp * dloss / TEMPERATURE + tl.debug_barrier() + for start_n in tl.range(0, N, BLOCK_N): + cols = start_n + tl.arange(0, BLOCK_N) + logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE + probs = tl.exp(logits - lse) + dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp + tl.store(DLOGITS + cols, dlogits, mask=cols < N) + + +def _compute_dapo_normalizer(completion_mask): + """Global active tokens averaged per process (for distributed DAPO loss).""" + normalizer = completion_mask.to(torch.float32).sum() + world_size = 1 + if torch.distributed.is_available() and torch.distributed.is_initialized(): + normalizer = normalizer.clone() + torch.distributed.all_reduce(normalizer, op=torch.distributed.ReduceOp.SUM) + world_size = torch.distributed.get_world_size() + normalizer = normalizer / world_size + return torch.clamp(normalizer, min=1.0) + + +def _reduce_loss(per_token_loss, mask, loss_type, max_completion_length, B, L): + """Apply loss reduction based on loss_type.""" + if loss_type == "grpo" or loss_type == "sapo": + return ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() + elif loss_type == "bnpo": + return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0) + elif loss_type == "dr_grpo": + max_len = max_completion_length if max_completion_length is not None else L + return (per_token_loss * mask).sum() / (B * max_len) + elif loss_type == "dapo" or loss_type == "cispo": + return (per_token_loss * mask).sum() / _compute_dapo_normalizer(mask) + elif loss_type == "luspo": + return (per_token_loss * mask.sum(-1, keepdim=True)).mean() + raise ValueError(f"Unknown loss_type: {loss_type}. Expected one of: grpo, bnpo, dr_grpo, dapo, cispo, sapo, luspo") + + +class GrpoLossFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + inplace, + loss_type="grpo", + max_completion_length=None, + reduce=True, + importance_sampling_level="token", + sapo_temperature_pos=1.0, + sapo_temperature_neg=1.05, + vllm_is_ratio=None, + delta=None, + use_bias_correction_kl=False, + ): + assert logits.is_contiguous() and completion_ids.is_contiguous() + assert old_logp is None or old_logp.is_contiguous() + assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True + assert importance_sampling_level in ("token", "sequence"), ( + f"importance_sampling_level must be 'token' or 'sequence', got {importance_sampling_level}" + ) + + # Validate loss_type + if loss_type not in _str_to_loss_type: + raise ValueError(f"Unknown loss_type '{loss_type}'. Supported types: {list(_str_to_loss_type.keys())}") + + # Validate delta + loss_type combinations + if delta is not None and loss_type in ("cispo", "sapo"): + raise ValueError(f"delta (two-sided clipping) is not supported for loss_type='{loss_type}'.") + + # Map delta to float for Triton (Triton can't handle None) + delta_val = 0.0 if delta is None else float(delta) + + # Validate sequence-level + loss_type combinations + if importance_sampling_level == "sequence" and loss_type in ("cispo", "sapo"): + raise ValueError( + f"Sequence-level importance sampling is not supported for loss_type='{loss_type}'. 
" + f"Use importance_sampling_level='token' instead." + ) + + # Validate SAPO temperatures to prevent division by zero or numerical instability + if loss_type == "sapo": + if sapo_temperature_pos <= 0: + raise ValueError(f"sapo_temperature_pos must be positive, got {sapo_temperature_pos}") + if sapo_temperature_neg <= 0: + raise ValueError(f"sapo_temperature_neg must be positive, got {sapo_temperature_neg}") + + # Convert loss_type string to integer for Triton constexpr + loss_type_int = _str_to_loss_type[loss_type] + + B, L_ADD_1, N = logits.shape + L = L_ADD_1 - 1 + + if completion_mask is not None: + assert completion_mask.is_contiguous() + + mask = completion_mask.float() if completion_mask is not None else torch.ones(B, L, device=logits.device) + + # Handle vLLM IS ratio + vllm_is_ratio_ptr = None + vllm_is_ratio_stride = L # default to per-token (unused when ptr is None) + if vllm_is_ratio is not None: + assert vllm_is_ratio.dim() in (1, 2), ( + f"vllm_is_ratio must be 1D (B,) or 2D (B, L) / (B, 1), got {vllm_is_ratio.dim()}D" + ) + if vllm_is_ratio.dim() == 2: + assert vllm_is_ratio.shape[0] == B and vllm_is_ratio.shape[1] in (1, L), ( + f"vllm_is_ratio shape must be ({B}, 1) or ({B}, {L}), got {tuple(vllm_is_ratio.shape)}" + ) + else: + assert vllm_is_ratio.shape[0] == B, ( + f"vllm_is_ratio shape must be ({B},), got {tuple(vllm_is_ratio.shape)}" + ) + vllm_is_ratio = vllm_is_ratio.contiguous() + vllm_is_ratio_ptr = vllm_is_ratio + vllm_is_ratio_stride = vllm_is_ratio.shape[1] if vllm_is_ratio.dim() > 1 else 1 + + # Allocate outputs + loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32) + lse = torch.zeros_like(loss) + is_clipped = torch.zeros_like(loss) + kl = torch.zeros_like(loss) if beta != 0.0 else None + + if importance_sampling_level == "sequence": + # Sequence-level: pre-compute sequence importance weights, then use Triton kernel + # Step 1: Get per-token log probs using existing Triton kernel + per_token_logps = fused_selective_log_softmax(logits, completion_ids, temperature, completion_mask) + + # Step 2: Compute sequence-level importance weights + if old_logp is None: + log_ratio = torch.zeros_like(per_token_logps) + else: + log_ratio = per_token_logps - old_logp + + seq_lens = mask.sum(-1).clamp(min=1.0) # (B,) + seq_log_importance = (log_ratio * mask).sum(-1) / seq_lens # (B,) + coef_1 = torch.exp(seq_log_importance) # (B,) + coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high) # (B,) + + # Compute is_clipped at sequence level (using original coef_1) + is_clipped_seq = ((coef_1 < 1 - eps_low) & (advantages < 0)) | ((coef_1 > 1 + eps_high) & (advantages > 0)) + is_clipped_seq = is_clipped_seq.float() # (B,) + + # Apply delta clamp for loss computation (keep original coef_1 for backward) + if delta is not None: + coef_1_for_loss = torch.clamp(coef_1, max=delta) + else: + coef_1_for_loss = coef_1 + + # Step 3: Run Triton kernel with pre-computed coefficients + kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1} + _grpo_loss_fwd_kernel_seq[(B, L)]( + logits, + old_logp, + ref_logp, + completion_ids, + completion_mask, + advantages, + coef_1_for_loss.contiguous(), + coef_2.contiguous(), + is_clipped_seq.contiguous(), + vllm_is_ratio_ptr, + vllm_is_ratio_stride, + loss, + lse, + kl, + is_clipped, + temperature, + beta, + use_bias_correction_kl, + L, + N, + **kwargs, + ) + + # Save extra tensors for backward + ctx.save_for_backward( + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + lse, + mask, + coef_1, + 
seq_lens, + vllm_is_ratio_ptr, + ) + else: + # Token-level: use optimized Triton kernel with LOSS_TYPE branching + kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1} + _grpo_loss_fwd_kernel[(B, L)]( + logits, + old_logp, + ref_logp, + completion_ids, + completion_mask, + advantages, + vllm_is_ratio_ptr, + vllm_is_ratio_stride, + loss, + lse, + kl, + is_clipped, + temperature, + beta, + eps_low, + eps_high, + loss_type_int, + sapo_temperature_pos, + sapo_temperature_neg, + delta_val, + use_bias_correction_kl, + L, + N, + **kwargs, + ) + ctx.save_for_backward( + logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse, mask, vllm_is_ratio_ptr + ) + + ctx.infos = ( + temperature, + beta, + eps_low, + eps_high, + inplace, + loss_type, + loss_type_int, + sapo_temperature_pos, + sapo_temperature_neg, + max_completion_length, + B, + L, + importance_sampling_level, + vllm_is_ratio_stride, + reduce, + delta_val, + use_bias_correction_kl, + ) + + # Compute metrics before reduction + mask_sum = mask.sum().clamp(min=1.0) + kl_mean = (kl * mask).sum() / mask_sum if kl is not None else None + clip_ratio = (is_clipped.float() * mask).sum() / mask_sum + + if not reduce: + loss_out = loss * mask + kl_out = kl * mask if kl is not None else None + is_clipped_out = is_clipped * mask + return loss_out, kl_out, is_clipped_out + + reduced_loss = _reduce_loss(loss, mask, loss_type, max_completion_length, B, L) + return reduced_loss, kl_mean, clip_ratio + + @staticmethod + def backward(ctx, *args): + dloss_input = args[0] + saved_tensors = ctx.saved_tensors + ( + temperature, + beta, + eps_low, + eps_high, + inplace, + loss_type, + loss_type_int, + sapo_temperature_pos, + sapo_temperature_neg, + max_completion_length, + B, + L, + importance_sampling_level, + vllm_is_ratio_stride, + reduce, + delta_val, + use_bias_correction_kl, + ) = ctx.infos + + if importance_sampling_level == "sequence": + ( + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + lse, + mask, + coef_1, + seq_lens, + vllm_is_ratio, + ) = saved_tensors + else: + (logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse, mask, vllm_is_ratio) = ( + saved_tensors + ) + + _, L_ADD_1, N = logits.shape + + # Compute per-token gradient scaling based on loss_type + if not reduce: + dloss = dloss_input + elif loss_type == "grpo" or loss_type == "sapo": + seq_lens_bwd = mask.sum(-1, keepdim=True).clamp(min=1.0) + dloss = dloss_input * mask / (seq_lens_bwd * B) + elif loss_type == "bnpo": + dloss = dloss_input * mask / mask.sum().clamp(min=1.0) + elif loss_type == "dr_grpo": + max_len = max_completion_length if max_completion_length is not None else L + dloss = dloss_input * mask / (B * max_len) + elif loss_type == "dapo" or loss_type == "cispo": + dloss = dloss_input * mask / _compute_dapo_normalizer(mask) + elif loss_type == "luspo": + # loss = mean(per_token_loss * seq_lens), mean divides by B*L + seq_lens_bwd = mask.sum(-1, keepdim=True).clamp(min=1.0) + dloss = dloss_input * seq_lens_bwd / (B * L) + else: + raise ValueError(f"Unknown loss_type: {loss_type}") + + dlogits = logits.data if inplace else torch.empty_like(logits) + kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16} + + if importance_sampling_level == "sequence": + if vllm_is_ratio is None: + dloss_sum = dloss.sum(-1).contiguous() + else: + if vllm_is_ratio.dim() == 1: + ratio = vllm_is_ratio.unsqueeze(-1) + else: + ratio = vllm_is_ratio + dloss_sum = (dloss * ratio).sum(-1).contiguous() + # 
Sequence-level backward kernel + _grpo_loss_bwd_kernel_seq[(B, L)]( + dloss, + dloss_sum, + dlogits, + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + lse, + coef_1, + seq_lens, + temperature, + beta, + use_bias_correction_kl, + eps_low, + eps_high, + delta_val, + *dloss.stride(), + L, + N, + **kwargs, + ) + else: + # Token-level backward kernel with LOSS_TYPE branching + _grpo_loss_bwd_kernel[(B, L)]( + dloss, + dlogits, + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + lse, + vllm_is_ratio, + vllm_is_ratio_stride, + temperature, + beta, + eps_low, + eps_high, + loss_type_int, + sapo_temperature_pos, + sapo_temperature_neg, + delta_val, + use_bias_correction_kl, + *dloss.stride(), + L, + N, + **kwargs, + ) + + dlogits[:, -1, :] = 0 + # Return gradients for all forward inputs: dlogits + 19 None for non-differentiable params + return ( + dlogits, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) diff --git a/src/liger_kernel/ops/jsd.py b/src/liger_kernel/ops/jsd.py new file mode 100755 index 0000000000000000000000000000000000000000..3115a254913b694745a940e73bdd6296c90ff5cd --- /dev/null +++ b/src/liger_kernel/ops/jsd.py @@ -0,0 +1,201 @@ +from typing import Optional + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.utils import infer_device + + +@triton.jit +def _jsd_kernel( + X_ptr, # input in logspace, X = log Q + X_stride, + Y_ptr, # ground truth in logspace, Y = log P + Y_stride, + loss_ptr, + loss_stride, + dX_ptr, + dX_stride, + label_ptr, + beta: tl.constexpr, + n_non_ignore: int, + ignore_index: tl.constexpr, + n_cols, + BLOCK_SIZE: tl.constexpr, + HAS_LABEL: tl.constexpr, +): + # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X) + # = sum(P * log P + Q * log Q - 2 * M * log M) / 2 + # = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2 + # grad_x_i = 0.5 * Q * (X - log_M) + pid = tl.program_id(0).to(tl.int64) + X_ptr += pid * X_stride + dX_ptr += pid * dX_stride + Y_ptr += pid * Y_stride + loss_ptr += pid * loss_stride + label_ptr += pid + + if HAS_LABEL: + label = tl.load(label_ptr) + if label == ignore_index: + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols) + return + + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32) + Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32) + + if beta == 0.0: # forward KL + Y_max = tl.max(Y, axis=0) + Y_shifted = Y - Y_max + Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max) # Compensate for the shift + loss = Y_prob * (Y - X) + dX = -Y_prob + elif beta == 1.0: # reverse KL + X_max = tl.max(X, axis=0) + X_shifted = X - X_max + X_prob = tl.exp(X_shifted) * tl.exp(X_max) # Compensate for the shift + loss = X_prob * (X - Y) + dX = loss + X_prob + else: + max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0)) + X_shifted = X - max_val + Y_shifted = Y - max_val + + # Pre-compute exp(max_val) since it's used twice + exp_max = tl.exp(max_val) + + # Compute exp terms with compensation + Q = tl.exp(X_shifted) * exp_max # = exp(X) + P = tl.exp(Y_shifted) * exp_max # = exp(Y) + + # Pre-compute common terms + beta_P = 
beta * P + one_minus_beta_Q = (1 - beta) * Q + M = beta_P + one_minus_beta_Q + log_M = tl.log(M) # No need to compensate as M is already in original scale + + loss = beta_P * Y + one_minus_beta_Q * X - M * log_M + dX = one_minus_beta_Q * (X - log_M) + + # Pre-compute scaling factor + scale = 1.0 / n_non_ignore + loss = loss * scale + dX = dX * scale + + tl.store(loss_ptr + offsets, loss, mask=mask) + tl.store(dX_ptr + offsets, dX, mask=mask) + + +MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 + + +def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label): + BT, V = _input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + # non reduction loss + loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device) + dX = torch.empty_like(_input) + + if has_label: + n_non_ignore = (shift_labels != ignore_index).sum().item() + else: + n_non_ignore = BT + + _jsd_kernel[(n_rows,)]( + X_ptr=_input, # input in logspace, X = log Q + X_stride=_input.stride(-2), + Y_ptr=target, # ground truth in logspace, Y = log P + Y_stride=target.stride(-2), + loss_ptr=loss, + loss_stride=loss.stride(-2), + dX_ptr=dX, + dX_stride=dX.stride(-2), + label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)), # dummy ptr if no label + beta=beta, + n_non_ignore=n_non_ignore, + ignore_index=ignore_index, + n_cols=V, + BLOCK_SIZE=BLOCK_SIZE, + HAS_LABEL=has_label, + ) + + loss = torch.sum(loss) + return loss.to(_input.dtype), dX + + +def jsd_backward(dX, grad_output): + # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + return dX + else: + return grad_output * dX + + +class LigerJSDFunction(torch.autograd.Function): + r""" + This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence. + .. math:: + JSD(\beta)(P || Q) + = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q)) + + .. note:: + As all the other losses in PyTorch, this function expects the first argument, + :attr:`_input`, to be the predictions, the output of the student model, in log-space + and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space. + This differs from the standard mathematical notation :math:`JSD(P || Q)` where + :math:`P` denotes the teacher model and :math:`Q` denotes the student model. + """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, + _input: torch.Tensor, + target: torch.Tensor, + shift_labels: Optional[torch.Tensor] = None, + beta: float = 0.5, + ignore_index: int = -100, + ) -> torch.Tensor: + """ + Args: + _input (torch.Tensor): predict values with shape (BT, V) in logspace + target (torch.Tensor): ground truth values with shape (BT, V) in logspace + shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1]. + beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5` + ignore_index (int): the index to ignore. Default: -100 + + Returns: + loss (torch.Tensor): generalized JSD + """ + has_label = False + if shift_labels is not None: + assert shift_labels.shape == (_input.shape[0],), ( + f"the shape of shift_labels must be (BT,). 
Got: {shift_labels.shape}" + ) + shift_labels = shift_labels.contiguous() + has_label = True + + loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label) + ctx.save_for_backward(dX) + return loss + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + (dX,) = ctx.saved_tensors + dX = jsd_backward(dX, grad_output) + return ( + dX, + None, + None, + None, + None, + ) diff --git a/src/liger_kernel/ops/kl_div.py b/src/liger_kernel/ops/kl_div.py new file mode 100755 index 0000000000000000000000000000000000000000..273294072633528d7e07bffa6b74086d85a13ff4 --- /dev/null +++ b/src/liger_kernel/ops/kl_div.py @@ -0,0 +1,259 @@ +from typing import Literal + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import is_hip +from liger_kernel.utils import infer_device + + +def get_num_warps(BLOCK_SIZE): + num_warps = 4 + if BLOCK_SIZE >= 32768: + num_warps = 32 if not is_hip() else 16 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 + + return num_warps + + +if infer_device() == "xpu": + MAX_FUSED_SIZE = 8192 +elif infer_device() == "npu": + MAX_FUSED_SIZE = 8192 +else: + MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best + +REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"] + +_REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0) +_REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1) +_REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2) +_REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3) + +_str_to_reduction_mode = { + "none": _REDUCTION_MODE_NONE.value, + "sum": _REDUCTION_MODE_SUM.value, + "mean": _REDUCTION_MODE_MEAN.value, + "batchmean": _REDUCTION_MODE_BATCHMEAN.value, +} + + +@triton.jit +def _kldiv_kernel_forward( + y_ptr, # [B, S], prediction ptr, the kernel expects the prediction in log-space + y_stride, # int, prediction stride + gt_ptr, # [B, S], ground truth ptr + gt_stride, # int, ground truth stride + loss_ptr, # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr + loss_stride, # int, output stride + n_cols, # int, number of columns in the input tensor + eps, + BLOCK_SIZE: tl.constexpr, + log_target: tl.constexpr = False, + reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN, +): + pid = tl.program_id(0).to(tl.int64) + y_ptr += pid * y_stride + gt_ptr += pid * gt_stride + loss_ptr += pid * loss_stride + + base_offsets = tl.arange(0, BLOCK_SIZE) + + loss_sum = 0.0 + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + base_offsets + mask = offsets < n_cols + y = tl.load(y_ptr + offsets, mask=mask, other=0.0) + y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0) + + # KL(y_true || y) = y_true * (log(y_true) - log(y)) + # We compute KL(y_true || y) with y in the log-space + if not log_target: + loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y) + else: + loss = tl.exp(y_true) * (y_true - y) + + if reduction == _REDUCTION_MODE_NONE: + tl.store(loss_ptr + offsets, loss, mask=mask) + else: + loss_sum += tl.sum(loss, axis=0) + + if reduction != _REDUCTION_MODE_NONE: + tl.store(loss_ptr, loss_sum) + + +@triton.jit +def _kldiv_kernel_backward( + target_ptr, + target_stride, + new_grads_ptr, + new_grads_stride, + n_cols, + BLOCK_SIZE: tl.constexpr, + log_target: tl.constexpr = False, +): + pid = tl.program_id(0).to(tl.int64) + + target_ptr += pid * target_stride + new_grads_ptr += pid * new_grads_stride + + offsets = tl.arange(0, BLOCK_SIZE) + mask 
= offsets < n_cols + + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + + target = tl.load(target_ptr + offsets, mask=mask, other=0.0) + + if not log_target: + res = target * -1 + else: + res = -tl.exp(target) + + tl.store(new_grads_ptr + offsets, res, mask=mask) + + +def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V] + BT, V = y_pred.shape + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE) + + grid = (BT,) + reduction = _str_to_reduction_mode[reduction] + + out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,) + output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32) + + _kldiv_kernel_forward[grid]( + y_pred, + y_pred.stride(0), + y_true, + y_true.stride(0), + output_tensor, + output_tensor.stride(0), + V, + eps=eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + log_target=log_target, + reduction=reduction, + ) + + # Reduce according to the reduction mode, matching PyTorch. In later PyTorch versions, `mean` is slated to behave the same as `batchmean` + # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html + # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372 + if reduction == _REDUCTION_MODE_BATCHMEAN.value: + return output_tensor.sum() / BT + elif reduction == _REDUCTION_MODE_SUM.value: + return output_tensor.sum(dim=0) + elif reduction == _REDUCTION_MODE_MEAN.value: + return output_tensor.sum() / (BT * V) + else: + return output_tensor + + +def kldiv_backward_triton(target, grad_output, new_grads, log_target): + BT, V = target.shape + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE) + + grid = (BT,) + + # The gradients are written into the caller-provided `new_grads` buffer + _kldiv_kernel_backward[grid]( + target, + target.stride(0), + new_grads, + new_grads.stride(0), + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + log_target=log_target, + ) + + # If the KL divergence loss is the last layer, grad_output is 1.0. Skip the mul then. + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + return new_grads + + return new_grads * grad_output + + +class LigerKLDivLossFunction(torch.autograd.Function): + """ + Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula: + ```python + if log_target: + loss = target.exp() * (target - input) + else: + loss = target * (target.log() - input) + ``` + The loss is then reduced according to the `reduction` parameter, as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html + """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, + y_pred: torch.Tensor, + y_true: torch.Tensor, + reduction: REDUCTION_LITERAL = "batchmean", + log_target: bool = False, + eps: float = 1e-10, + ) -> torch.Tensor: + """A forward pass for the KL Divergence Loss. + + Args: + ctx: Torch autograd context + y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities. + y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`. + reduction (REDUCTION_LITERAL, optional): Reduction to be used.
Defaults to "batchmean". + log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False. + eps: (float, optional): A small value to avoid division by zero. Defaults to 1e-10. + + Returns: + torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar. + """ + ctx.save_for_backward(y_true) + ctx.reduction = reduction + ctx.log_target = log_target + return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps) + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + """A backward pass for the KL Divergence Loss. + + Args: + ctx: Torch autograd context + grad_output (torch.Tensor): The gradient of the loss with respect to the output. + + Returns: + tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method. + """ + (y_true,) = ctx.saved_tensors + + new_grads = torch.empty_like(y_true) + + derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target) + + if ctx.reduction == "batchmean": + derivative = derivative / y_true.shape[0] + elif ctx.reduction == "sum" or ctx.reduction == "none": + pass + elif ctx.reduction == "mean": + derivative = derivative / (y_true.shape[0] * y_true.shape[1]) + + return ( + derivative, + None, + None, + None, + None, + ) diff --git a/src/liger_kernel/ops/layer_norm.py b/src/liger_kernel/ops/layer_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..e8ac6b5f39b8ed227b68f75cc79f72fc4012b5c2 --- /dev/null +++ b/src/liger_kernel/ops/layer_norm.py @@ -0,0 +1,320 @@ +import math +import operator + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import calculate_settings +from liger_kernel.ops.utils import compare_version +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count +from liger_kernel.ops.utils import set_large_grf_mode +from liger_kernel.utils import is_npu_available + +if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available(): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import rsqrt + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import rsqrt +else: + from triton.language.math import rsqrt + + +@triton.jit +def _layer_norm_forward_kernel( + Y_ptr, # pointer to output, shape (n_rows, n_cols) + Y_row_stride, # stride of each row in output + X_ptr, # pointer to input, shape (n_rows, n_cols) + X_row_stride, # stride of each row in input + W_ptr, # pointer to weights, shape (n_cols,) + W_row_stride, # stride of each row in weights + B_ptr, # pointer to bias, shape (n_cols,) + B_row_stride, # stride of each row in bias + Mean_ptr, # pointer to mean, shape (n_rows,) + Mean_row_stride, # stride of each row in mean + RSTD_ptr, # pointer to rstd, shape (n_rows,) + RSTD_row_stride, # stride of each row in rstd + n_cols, + eps, + BLOCK_SIZE: tl.constexpr, +): + """ + References: + https://arxiv.org/abs/1607.06450 + https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md + """ + row_idx = tl.program_id(0).to(tl.int64) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + # Pre-load weights and bias in fp32 to avoid repeated conversions + W_row = tl.load(W_ptr + col_offsets, mask=mask, 
other=0.0) + B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0) + W_f32 = W_row.to(tl.float32) + B_f32 = B_row.to(tl.float32) + + # Calculate pointers for this row + row_X_ptr = X_ptr + row_idx * X_row_stride + row_Y_ptr = Y_ptr + row_idx * Y_row_stride + row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride + row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride + + # Load input data and convert to fp32 for numerical stability + X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0) + X_f32 = X_row.to(tl.float32) + + # Compute statistics in fp32 for numerical stability + mean = tl.sum(X_f32, axis=0) / n_cols + X_centered = X_f32 - mean + # Apply mask to variance calculation to exclude contributions from masked elements + X_centered_masked = tl.where(mask, X_centered, 0.0) + var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols + rstd = rsqrt(var + eps) + + # Store statistics (convert back to original dtype only once) + tl.store(row_Mean_ptr, mean.to(X_row.dtype)) + tl.store(row_RSTD_ptr, rstd.to(X_row.dtype)) + + # Fused normalization and affine transformation + # Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B + Y_f32 = X_centered * rstd * W_f32 + B_f32 + + # Store output (single conversion back to original dtype) + tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask) + + +@triton.jit +def _layer_norm_backward_kernel( + X_ptr, # pointer to input, shape (n_rows, n_cols) + stride_x, # stride of each row in input + W_ptr, # pointer to weights, shape (n_cols,) + Mean_ptr, # pointer to mean, shape (n_rows,) + stride_mean, # stride of each row in mean + RSTD_ptr, # pointer to rstd, shape (n_rows,) + stride_rstd, # stride of each row in rstd + DX_ptr, # pointer to input grad, shape (n_rows, n_cols) + stride_dx, # stride of each row in input grad + DW_ptr, # pointer to weights grad, shape (n_cols,) + stride_dw, # stride of each row in weights grad + DB_ptr, # pointer to bias grad, shape (n_cols,) + stride_db, # stride of each row in bias grad + DY_ptr, # pointer to output grad, shape (n_rows, n_cols) + stride_dy, # stride of each row in output grad + n_rows, + n_cols, + rows_per_program: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + References: + https://arxiv.org/abs/1607.06450 + https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md + """ + row_block_id = tl.program_id(0).to(tl.int64) + row_start = row_block_id * rows_per_program + row_end = min((row_block_id + 1) * rows_per_program, n_rows) + cols = tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + + dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + + # Pre-load weights once (same optimization as forward pass) + w = tl.load(W_ptr + cols, mask=mask, other=0.0) + w_f32 = w.to(tl.float32) + + for row_idx in range(row_start, row_end): + # Calculate pointers for this specific row + row_X_ptr = X_ptr + row_idx * stride_x + row_DX_ptr = DX_ptr + row_idx * stride_dx + row_DY_ptr = DY_ptr + row_idx * stride_dy + row_Mean_ptr = Mean_ptr + row_idx * stride_mean + row_RSTD_ptr = RSTD_ptr + row_idx * stride_rstd + + # Load data for this row + x = tl.load(row_X_ptr + cols, mask=mask, other=0.0) + dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0) + mean = tl.load(row_Mean_ptr) + rstd = tl.load(row_RSTD_ptr) + + # Convert to fp32 for numerical stability + x_f32 = x.to(tl.float32) + dy_f32 = dy.to(tl.float32) + mean_f32 = mean.to(tl.float32) + rstd_f32 = rstd.to(tl.float32) + + # Compute backward pass for this row + x_hat = 
(x_f32 - mean_f32) * rstd_f32
+        wdy = w_f32 * dy_f32
+        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+        c2 = tl.sum(wdy, axis=0) / n_cols
+        dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+        # Store input gradient
+        tl.store(row_DX_ptr + cols, dx, mask=mask)
+
+        # Accumulate weight and bias gradients for this thread block's assigned rows
+        dw = dy_f32 * x_hat
+        db = dy_f32
+        dW_row += dw
+        db_row += db
+
+    tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
+    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
+
+
+def layer_norm_forward(X, W, B, eps):
+    """
+    Args:
+        X: Input tensor of shape (..., hidden_size)
+        W: Weight tensor of shape (hidden_size,)
+        B: Bias tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Tuple of (output, input, mean, rstd, block_size, num_warps)
+    """
+    shape = X.shape
+    dim = shape[-1]
+    X = X.view(-1, dim)
+    n_rows, n_cols = X.shape
+
+    # Calculate optimal block size and warp configuration
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    # Allocate output tensors
+    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+    Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+    RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+
+    # Validate input dimensions
+    if X.shape[1] != W.shape[0]:
+        raise ValueError(
+            f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
+            f"must match weight size (W.shape[0]={W.shape[0]})"
+        )
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        set_large_grf_mode(kernel_args)
+
+    # Launch kernel with one thread block per row for optimal performance
+    grid = (n_rows,)
+    _layer_norm_forward_kernel[grid](
+        Y,
+        Y.stride(0),
+        X,
+        X.stride(0),
+        W,
+        W.stride(0),
+        B,
+        B.stride(0),
+        Mean,
+        Mean.stride(0),
+        RSTD,
+        RSTD.stride(0),
+        n_cols,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        **kernel_args,
+    )
+
+    return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
+
+
+def layer_norm_backward(dY, X, W, B, Mean, RSTD):
+    """
+    Args:
+        dY: Gradient of output
+        X: Input tensor
+        W: Weight tensor
+        B: Bias tensor
+        Mean: Pre-computed mean
+        RSTD: Pre-computed reciprocal standard deviation
+
+    Returns:
+        Tuple of (input_grad, weight_grad, bias_grad)
+    """
+    shape = dY.shape
+    dim = shape[-1]
+    dY = dY.view(-1, dim)
+    n_rows, n_cols = dY.shape
+
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_core_count()
+
+    # Accumulate partial weight/bias gradients in fp32 for numerical stability;
+    # the per-program partials are summed (and cast back) after the kernel.
+    _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+
+    # Calculate optimal block size and warp configuration
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    if n_cols > BLOCK_SIZE:
+        raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
+    rows_per_program = math.ceil(n_rows / sm_count)
+    grid = (sm_count,)
+
+    # Allocate gradient tensors
+    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+
+    kernel_args = {"num_warps": num_warps}
+    # XPU-specific optimization
+    if X.device.type == "xpu":
+        kernel_args.update({"num_warps": 32, "num_stages": 4})
+        set_large_grf_mode(kernel_args)
+
+    # Launch one program per multiprocessor; each program accumulates gradients
+    # for its `rows_per_program` assigned rows
+    _layer_norm_backward_kernel[grid](
+        X,
+        X.stride(0),
+        W,
+        Mean,
+        Mean.stride(0),
+        RSTD,
+        RSTD.stride(0),
+        DX,
+        DX.stride(0),
+        _DW,
+        _DW.stride(0),
+        _DB,
+        _DB.stride(0),
+        dY,
+        dY.stride(0),
+        n_rows,
+        n_cols,
+        rows_per_program=rows_per_program,
+        BLOCK_SIZE=BLOCK_SIZE,
+        **kernel_args,
+    )
+
+    DX = DX.view(*shape)
+    DW = _DW.sum(dim=0).to(W.dtype)
+    DB = _DB.sum(dim=0).to(B.dtype)
+
+    return DX, DW, DB
+
+
+class LigerLayerNormFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, X, W, B, eps):
+        Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps)
+        ctx.save_for_backward(X, W, B, Mean, RSTD)
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dY):
+        X, W, B, Mean, RSTD = ctx.saved_tensors
+        DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
+        return DX, DW, DB, None
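+
+
+# A minimal usage sketch (illustrative only; a CUDA device is assumed and the
+# hidden size must fit within the maximum supported block size):
+#
+#     x = torch.randn(8, 128, 1024, device="cuda", requires_grad=True)
+#     w = torch.ones(1024, device="cuda", requires_grad=True)
+#     b = torch.zeros(1024, device="cuda", requires_grad=True)
+#     y = LigerLayerNormFunction.apply(x, w, b, 1e-6)
+#     y.sum().backward()  # populates x.grad, w.grad, and b.grad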
diff --git a/src/liger_kernel/ops/llama4_rope.py b/src/liger_kernel/ops/llama4_rope.py
new file mode 100755
index 0000000000000000000000000000000000000000..9167d69f35d76998367a7a5bb095c130aa8c39f1
--- /dev/null
+++ b/src/liger_kernel/ops/llama4_rope.py
@@ -0,0 +1,180 @@
+import torch
+import triton
+import triton.language as tl
+
+
+def _cast_and_contiguous(q, k, freqs_complex):
+    # Compute in q's dtype: fp32 inputs stay fp32, and half-precision inputs keep
+    # their dtype for performance; k is aligned to q's dtype.
+    compute_dtype = q.dtype
+
+    if k.dtype != q.dtype:
+        k = k.to(q.dtype)
+
+    q = q.to(compute_dtype).contiguous()
+    k = k.to(compute_dtype).contiguous()
+    freqs_complex = freqs_complex.contiguous()
+    return q, k, freqs_complex
+
+
+@triton.jit
+def _llama4_rope_kernel(
+    q_ptr,
+    k_ptr,
+    freqs_complex_ptr,
+    q_row_stride,
+    k_row_stride,
+    q_head_stride,
+    k_head_stride,
+    freqs_row_stride,
+    seq_len,
+    batch_size,
+    imag_sign,
+    head_dim_half: tl.constexpr,
+    n_q_heads: tl.constexpr,
+    n_k_heads: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    H100-optimized RoPE kernel with improved parallelization across heads and dimensions.
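+    The last dimension of q/k is treated as interleaved (real, imag) pairs, and
+    freqs_complex is expected to be the [seq, head_dim_half, 2] real view that
+    the host wrapper produces via torch.view_as_real.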
+ Grid: (batch*seq, head) + """ + # 2D grid + pid_bs = tl.program_id(0) # over batch*seq + pid_h = tl.program_id(1) # over heads + + batch_idx = pid_bs // seq_len + seq_idx = pid_bs % seq_len + + # Bounds check + if batch_idx >= batch_size or seq_idx >= seq_len: + return + + # Base pointers for this (batch, seq) position + base_offset = batch_idx * seq_len + seq_idx + q_base = q_ptr + base_offset * q_row_stride + k_base = k_ptr + base_offset * k_row_stride + freq_base = seq_idx * freqs_row_stride + + # Tiling over dim/2 + for d_start in tl.static_range(0, head_dim_half, BLOCK_SIZE): + d_indices = d_start + tl.arange(0, BLOCK_SIZE) + mask_d = d_indices < head_dim_half + + # Compute offsets for the block + freq_offsets = d_indices[:, None] * 2 + tl.arange(0, 2)[None, :] + # Load the block + freqs_complex = tl.load(freqs_complex_ptr + freq_base + freq_offsets, mask=mask_d[:, None], other=0.0) + freqs_real, freqs_imag = tl.split(freqs_complex) + freqs_imag = freqs_imag * imag_sign + + # Process one query head per program in pid_h + if pid_h < n_q_heads: + q_head_ptr = q_base + pid_h * q_head_stride + q_real = tl.load(q_head_ptr + d_indices * 2, mask=mask_d, other=0.0) + q_imag = tl.load(q_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0) + + # Complex multiply with FMAs: (a+ib)*(c+i d) = (a*c - b*d) + i(a*d + b*c) + new_q_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag)) + new_q_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real) + + tl.store(q_head_ptr + d_indices * 2, new_q_real, mask=mask_d) + tl.store(q_head_ptr + d_indices * 2 + 1, new_q_imag, mask=mask_d) + + # Process one key head per program in pid_h + if pid_h < n_k_heads: + k_head_ptr = k_base + pid_h * k_head_stride + k_real = tl.load(k_head_ptr + d_indices * 2, mask=mask_d, other=0.0) + k_imag = tl.load(k_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0) + + new_k_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag)) + new_k_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real) + + tl.store(k_head_ptr + d_indices * 2, new_k_real, mask=mask_d) + tl.store(k_head_ptr + d_indices * 2 + 1, new_k_imag, mask=mask_d) + + +def _select_kernel_meta(head_dim_half: int): + # Heuristic tuning for block size and num_warps + if head_dim_half >= 256: + return 128, 8 + if head_dim_half >= 96: + return 128, 4 + if head_dim_half >= 48: + return 64, 4 + if head_dim_half >= 24: + return 32, 2 + return 16, 2 + + +def llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE: int = None, imag_sign: float = 1.0): + # Save original dtype for casting back + original_dtype = q.dtype + + batch_size, seq_len, n_q_heads, head_dim = q.shape + _, _, n_k_heads, _ = k.shape + head_dim_half = head_dim // 2 + if freqs_cis.is_complex(): + freqs_cis = freqs_cis.reshape(-1, freqs_cis.shape[-1]) + if freqs_cis.shape[0] > seq_len: + freqs_cis = freqs_cis[:seq_len] + freqs_cis = torch.view_as_real(freqs_cis) + + # Cast to appropriate dtype and make contiguous only when needed + q, k, freqs_cis = _cast_and_contiguous(q, k, freqs_cis) + + # H100-optimized meta-params + if BLOCK_SIZE is None: + BLOCK_SIZE, num_warps = _select_kernel_meta(head_dim_half) + else: + # Provide a default num_warps if caller pins BLOCK_SIZE + _, num_warps = _select_kernel_meta(head_dim_half) + + # 2D grid: one program per (batch, seq, head) + n_heads_max = max(n_q_heads, n_k_heads) + grid = (batch_size * seq_len, n_heads_max) + + # Launch kernel + _llama4_rope_kernel[grid]( + q, + k, + freqs_cis, + q.stride(1), + k.stride(1), + q.stride(2), + k.stride(2), + 
freqs_cis.stride(0), + seq_len, + batch_size, + imag_sign, + head_dim_half, + n_q_heads, + n_k_heads, + BLOCK_SIZE, + num_warps=num_warps, + num_stages=2, + ) + + # Cast back to original dtype only if it differs from compute dtype + if q.dtype != original_dtype: + q = q.to(original_dtype) + if k.dtype != original_dtype: + k = k.to(original_dtype) + + return q, k + + +class LigerLlama4RopeFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None): + q_out, k_out = llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE, imag_sign=1.0) + ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis) + ctx.BLOCK_SIZE = BLOCK_SIZE + return q_out, k_out + + @staticmethod + def backward(ctx, dq, dk): + (freqs_cis,) = ctx.saved_tensors + BLOCK_SIZE = getattr(ctx, "BLOCK_SIZE", None) + # Use imag_sign=-1.0 for conjugate without materializing a new tensor + dq_out, dk_out = llama4_rope_forward(dq, dk, freqs_cis, BLOCK_SIZE, imag_sign=-1.0) + return dq_out, dk_out, None diff --git a/src/liger_kernel/ops/mhc.py b/src/liger_kernel/ops/mhc.py new file mode 100755 index 0000000000000000000000000000000000000000..1a4569d334137a3102070e24543655df361cd6b6 --- /dev/null +++ b/src/liger_kernel/ops/mhc.py @@ -0,0 +1,1674 @@ +import math + +from typing import Any +from typing import Optional +from typing import Tuple +from typing import Union + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import ensure_contiguous + + +def _post_res_default_meta(c: int) -> Tuple[int, int, int, int]: + """ + Returns default (block_n, block_c, num_warps, num_stages) for post_res kernels. + Tuned for different hidden dimensions on NVIDIA GPUs. + """ + if c >= 4096: + return 32, 128, 8, 3 # (block_n, block_c, num_warps, num_stages) + if c >= 2048: + return 32, 128, 4, 2 + if c >= 1024: + return 32, 64, 4, 2 + return 32, 64, 2, 2 + + +def _post_res_meta( + c: int, + block_n: Optional[int], + block_c: Optional[int], + num_warps: Optional[int], + num_stages: Optional[int], +) -> Tuple[int, int, int, int]: + bn, bc, nw, ns = _post_res_default_meta(c) + return ( + bn if block_n is None else int(block_n), + bc if block_c is None else int(block_c), + nw if num_warps is None else int(num_warps), + ns if num_stages is None else int(num_stages), + ) + + +# ------------------------------------------------------------------------------------------------- +# (1) Coefficients: fused matmul + RMS scalar (Eq. 
14–15) +# mix = (x @ phi) * rsqrt(mean(x^2) + eps) +# +# We provide two paths: +# - TC path: x BF16/FP16 and phi BF16/FP16 (Tensor Cores) +# - TF32-ish path: x cast to FP32 and phi FP32 (relies on Triton/arch for TF32) +# ------------------------------------------------------------------------------------------------- + + +@triton.jit +def _mhc_mm_norm_fwd_kernel( + x_ptr, + phi_ptr, + mix_ptr, + invr_ptr, + N: tl.constexpr, + K: tl.constexpr, + M: tl.constexpr, + stride_xn: tl.constexpr, + stride_xk: tl.constexpr, + stride_phik: tl.constexpr, + stride_phim: tl.constexpr, + stride_mn: tl.constexpr, + stride_mm: tl.constexpr, + eps: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_M: tl.constexpr, + CAST_FP32: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_m = tl.program_id(1) + + n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + m_offs = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + acc = tl.zeros((BLOCK_N, BLOCK_M), tl.float32) + sumsq = tl.zeros((BLOCK_N,), tl.float32) + + for k0 in tl.static_range(0, K, BLOCK_K): + k_offs = k0 + tl.arange(0, BLOCK_K) + + x = tl.load( + x_ptr + n_offs[:, None] * stride_xn + k_offs[None, :] * stride_xk, + mask=(n_offs[:, None] < N) & (k_offs[None, :] < K), + other=0.0, + ) + if CAST_FP32: + x = x.to(tl.float32) + sumsq += tl.sum(x * x, axis=1) + else: + x_f = x.to(tl.float32) + sumsq += tl.sum(x_f * x_f, axis=1) + + phi = tl.load( + phi_ptr + k_offs[:, None] * stride_phik + m_offs[None, :] * stride_phim, + mask=(k_offs[:, None] < K) & (m_offs[None, :] < M), + other=0.0, + ) + if CAST_FP32: + phi = phi.to(tl.float32) + + acc += tl.dot(x, phi) + + invr = tl.rsqrt(sumsq / K + eps) + out = acc * invr[:, None] + + tl.store( + mix_ptr + n_offs[:, None] * stride_mn + m_offs[None, :] * stride_mm, + out, + mask=(n_offs[:, None] < N) & (m_offs[None, :] < M), + ) + if pid_m == 0: + tl.store(invr_ptr + n_offs, invr, mask=n_offs < N) + + +def mhc_mm_norm_fwd( + x: torch.Tensor, + phi: torch.Tensor, + eps: float, + *, + out_mix: Optional[torch.Tensor] = None, + out_invr: Optional[torch.Tensor] = None, + block_n: int = 32, + block_k: int = 256, + block_m: int = 32, + num_warps: int = 4, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Fused (x @ phi) + invr = rsqrt(mean(x^2)+eps) and returns mix=(x@phi)*invr. 
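+
+    Example (illustrative sketch; the shapes are arbitrary and a Triton-capable
+    GPU is assumed):
+        >>> x = torch.randn(64, 4096, device="cuda", dtype=torch.bfloat16)
+        >>> phi = torch.randn(4096, 24, device="cuda", dtype=torch.bfloat16)
+        >>> mix, invr = mhc_mm_norm_fwd(x, phi, eps=1e-6)
+        >>> mix.shape, invr.shape
+        (torch.Size([64, 24]), torch.Size([64]))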
+ + Args: + x: [N, K] contiguous + phi: [K, M] contiguous + eps: float + Returns: + mix: [N, M] float32 + invr: [N] float32 + """ + assert x.is_contiguous(), "x must be contiguous" + assert phi.is_contiguous(), "phi must be contiguous" + + N, K = x.shape + K2, M = phi.shape + assert K2 == K, f"phi.shape[0] must match K: got {K2} vs {K}" + + if out_mix is None: + out_mix = torch.empty((N, M), device=x.device, dtype=torch.float32) + if out_invr is None: + out_invr = torch.empty((N,), device=x.device, dtype=torch.float32) + + grid = (triton.cdiv(N, block_n), triton.cdiv(M, block_m)) + + use_tc = (x.dtype == phi.dtype) and (x.dtype in (torch.float16, torch.bfloat16)) + + _mhc_mm_norm_fwd_kernel[grid]( + x, + phi, + out_mix, + out_invr, + N=N, + K=K, + M=M, + stride_xn=x.stride(0), + stride_xk=x.stride(1), + stride_phik=phi.stride(0), + stride_phim=phi.stride(1), + stride_mn=out_mix.stride(0), + stride_mm=out_mix.stride(1), + eps=eps, + BLOCK_N=block_n, + BLOCK_K=block_k, + BLOCK_M=block_m, + CAST_FP32=not use_tc, + num_warps=num_warps, + ) + return out_mix, out_invr + + +# ------------------------------------------------------------------------------------------------- +# Backward for fused (x @ phi) + RMS scalar +# +# mix = (x @ phi) * invr +# invr = rsqrt(mean(x^2) + eps) +# +# Given grad_mix, compute: +# grad_z = grad_mix * invr +# g = sum(grad_mix * (mix / invr)) = sum(grad_mix * mix) / invr +# factor = -(g / K) * invr^3 +# grad_x = grad_z @ phi^T + factor * x +# grad_phi = x^T @ grad_z +# +# grad_phi is accumulated into FP32 with atomic adds (split over N-chunks). +# ------------------------------------------------------------------------------------------------- + + +@triton.jit +def _mhc_mm_norm_bwd_fused_kernel( + x_ptr, + phi_ptr, + mix_ptr, + invr_ptr, + grad_mix_ptr, + grad_x_ptr, + grad_phi_ptr, + N: tl.constexpr, + K: tl.constexpr, + M: tl.constexpr, + stride_xn: tl.constexpr, + stride_xk: tl.constexpr, + stride_phik: tl.constexpr, + stride_phim: tl.constexpr, + stride_mn: tl.constexpr, + stride_mm: tl.constexpr, + stride_invr: tl.constexpr, + stride_gmn: tl.constexpr, + stride_gmm: tl.constexpr, + stride_gxn: tl.constexpr, + stride_gxk: tl.constexpr, + stride_gpk: tl.constexpr, + stride_gpm: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_M: tl.constexpr, + CAST_FP32: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_k = tl.program_id(1) + + n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + k_offs = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) + + invr = tl.load(invr_ptr + n_offs * stride_invr, mask=n_offs < N, other=0.0).to(tl.float32) + + x = tl.load( + x_ptr + n_offs[:, None] * stride_xn + k_offs[None, :] * stride_xk, + mask=(n_offs[:, None] < N) & (k_offs[None, :] < K), + other=0.0, + ) + if CAST_FP32: + x = x.to(tl.float32) + x_f = x + else: + x_f = x.to(tl.float32) + + acc = tl.zeros((BLOCK_N, BLOCK_K), tl.float32) + g_acc = tl.zeros((BLOCK_N,), tl.float32) + + for m0 in tl.static_range(0, M, BLOCK_M): + m_offs = m0 + tl.arange(0, BLOCK_M) + + grad_mix = tl.load( + grad_mix_ptr + n_offs[:, None] * stride_gmn + m_offs[None, :] * stride_gmm, + mask=(n_offs[:, None] < N) & (m_offs[None, :] < M), + other=0.0, + ).to(tl.float32) + + mix = tl.load( + mix_ptr + n_offs[:, None] * stride_mn + m_offs[None, :] * stride_mm, + mask=(n_offs[:, None] < N) & (m_offs[None, :] < M), + other=0.0, + ).to(tl.float32) + + g_acc += tl.sum(grad_mix * mix, axis=1) + + phi = tl.load( + phi_ptr + k_offs[:, None] * stride_phik + m_offs[None, :] * stride_phim, + mask=(k_offs[:, 
None] < K) & (m_offs[None, :] < M), + other=0.0, + ) + if CAST_FP32: + phi = phi.to(tl.float32) + grad_z = grad_mix * invr[:, None] + else: + grad_z = (grad_mix * invr[:, None]).to(phi.dtype) + + acc += tl.dot(grad_z, tl.trans(phi)) + + dphi = tl.dot(tl.trans(x), grad_z) + tl.atomic_add( + grad_phi_ptr + k_offs[:, None] * stride_gpk + m_offs[None, :] * stride_gpm, + dphi, + mask=(k_offs[:, None] < K) & (m_offs[None, :] < M), + ) + + g = g_acc / invr + invr3 = invr * invr * invr + factor = (-g * invr3) / K + + gx = acc + x_f * factor[:, None] + + if CAST_FP32: + tl.store( + grad_x_ptr + n_offs[:, None] * stride_gxn + k_offs[None, :] * stride_gxk, + gx, + mask=(n_offs[:, None] < N) & (k_offs[None, :] < K), + ) + else: + tl.store( + grad_x_ptr + n_offs[:, None] * stride_gxn + k_offs[None, :] * stride_gxk, + gx.to(x.dtype), + mask=(n_offs[:, None] < N) & (k_offs[None, :] < K), + ) + + +def mhc_mm_norm_bwd( + x: torch.Tensor, + phi: torch.Tensor, + mix: torch.Tensor, + invr: torch.Tensor, + grad_mix: torch.Tensor, + *, + out_grad_x: Optional[torch.Tensor] = None, + out_grad_phi: Optional[torch.Tensor] = None, + block_n: int = 32, + block_k: int = 256, + block_m: int = 32, + num_warps: int = 4, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Triton backward for `mhc_mm_norm_fwd`. + + Returns: + grad_x: [N, K] same dtype as x + grad_phi: [K, M] FP32 (safe for atomic adds; cast on return if needed) + + Note: + grad_phi is accumulated via atomic_add in FP32. For very large N + (batch * sequence length > 1M), accumulated rounding errors may + become noticeable. This is typically not an issue for standard + training configurations. + """ + assert ( + x.is_contiguous() + and phi.is_contiguous() + and mix.is_contiguous() + and invr.is_contiguous() + and grad_mix.is_contiguous() + ) + + N, K = x.shape + K2, M = phi.shape + assert K2 == K + assert mix.shape == (N, M) + assert grad_mix.shape == (N, M) + assert invr.shape == (N,) + + if out_grad_x is None: + out_grad_x = torch.empty_like(x) + if out_grad_phi is None: + out_grad_phi = torch.zeros((K, M), device=x.device, dtype=torch.float32) + + use_tc = (x.dtype == phi.dtype) and (x.dtype in (torch.float16, torch.bfloat16)) + + grid = (triton.cdiv(N, block_n), triton.cdiv(K, block_k)) + _mhc_mm_norm_bwd_fused_kernel[grid]( + x, + phi, + mix, + invr, + grad_mix, + out_grad_x, + out_grad_phi, + N=N, + K=K, + M=M, + stride_xn=x.stride(0), + stride_xk=x.stride(1), + stride_phik=phi.stride(0), + stride_phim=phi.stride(1), + stride_mn=mix.stride(0), + stride_mm=mix.stride(1), + stride_invr=invr.stride(0), + stride_gmn=grad_mix.stride(0), + stride_gmm=grad_mix.stride(1), + stride_gxn=out_grad_x.stride(0), + stride_gxk=out_grad_x.stride(1), + stride_gpk=out_grad_phi.stride(0), + stride_gpm=out_grad_phi.stride(1), + BLOCK_N=block_n, + BLOCK_K=block_k, + BLOCK_M=block_m, + CAST_FP32=not use_tc, + num_warps=num_warps, + ) + + if out_grad_phi.dtype != phi.dtype: + out_grad_phi = out_grad_phi.to(phi.dtype) + return out_grad_x, out_grad_phi + + +# ------------------------------------------------------------------------------------------------- +# Sinkhorn-Knopp forward/backward for H_res (Eq. 
19) +# ------------------------------------------------------------------------------------------------- + + +@triton.jit +def _mhc_split_sinkhorn_fwd_kernel( + mix_ptr, + b_ptr, + hpre_ptr, + hpost_ptr, + hres_ptr, + hist_ptr, + N: tl.constexpr, + HC: tl.constexpr, + M: tl.constexpr, + stride_mn: tl.constexpr, + stride_mm: tl.constexpr, + stride_hp_n: tl.constexpr, + stride_hp_h: tl.constexpr, + stride_hq_n: tl.constexpr, + stride_hq_h: tl.constexpr, + stride_hr_n: tl.constexpr, + stride_hr_i: tl.constexpr, + stride_hr_j: tl.constexpr, + stride_hn: tl.constexpr, + stride_ht: tl.constexpr, + stride_hi: tl.constexpr, + stride_hj: tl.constexpr, + alpha_pre_ptr, + alpha_post_ptr, + alpha_res_ptr, + pre_eps: tl.constexpr, + sinkhorn_eps: tl.constexpr, + post_mult: tl.constexpr, + TMAX: tl.constexpr, + STORE_HIST: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= N: + return + + # Load scalar alpha parameters from GPU memory (avoids CPU sync) + alpha_pre = tl.load(alpha_pre_ptr).to(tl.float32) + alpha_post = tl.load(alpha_post_ptr).to(tl.float32) + alpha_res = tl.load(alpha_res_ptr).to(tl.float32) + + # --- Pre/post logits + j = tl.arange(0, HC) + mix_pre = tl.load(mix_ptr + pid * stride_mn + j * stride_mm).to(tl.float32) + mix_post = tl.load(mix_ptr + pid * stride_mn + (HC + j) * stride_mm).to(tl.float32) + + b_pre = tl.load(b_ptr + j).to(tl.float32) + b_post = tl.load(b_ptr + (HC + j)).to(tl.float32) + + pre_logits = mix_pre * alpha_pre + b_pre + post_logits = mix_post * alpha_post + b_post + + pre = tl.sigmoid(pre_logits) + pre_eps + post = tl.sigmoid(post_logits) * post_mult + + tl.store(hpre_ptr + pid * stride_hp_n + j * stride_hp_h, pre) + tl.store(hpost_ptr + pid * stride_hq_n + j * stride_hq_h, post) + + # --- Residual logits matrix [HC, HC] + rows = tl.arange(0, HC)[:, None] + cols = tl.arange(0, HC)[None, :] + flat = rows * HC + cols # [HC,HC] + + mix_res = tl.load(mix_ptr + pid * stride_mn + (2 * HC + flat) * stride_mm).to(tl.float32) + b_res = tl.load(b_ptr + (2 * HC + flat)).to(tl.float32) + + logits = mix_res * alpha_res + b_res + + # Sinkhorn: initial row-softmax (stable) then alternating row/col norms. 
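+    # Each pair of steps divides rows, then columns, by their sums, driving `mat`
+    # toward a doubly stochastic matrix. Illustrative 2x2 sketch (eps ignored):
+    # starting from the row-normalized [[0.8, 0.2], [0.5, 0.5]], the column step
+    # gives [[0.8/1.3, 0.2/0.7], [0.5/1.3, 0.5/0.7]], and alternating steps
+    # converge until every row and column sums to ~1.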
+ row_max = tl.max(logits, axis=1) + e = tl.exp(logits - row_max[:, None]) + row_sum = tl.sum(e, axis=1) + mat = e / row_sum[:, None] + sinkhorn_eps + + col_sum = tl.sum(mat, axis=0) + mat = mat / (col_sum[None, :] + sinkhorn_eps) + + if STORE_HIST: + tl.store( + hist_ptr + pid * stride_hn + 0 * stride_ht + rows * stride_hi + cols * stride_hj, + mat, + ) + + for t in tl.static_range(0, TMAX - 1): + row_sum = tl.sum(mat, axis=1) + mat = mat / (row_sum[:, None] + sinkhorn_eps) + col_sum = tl.sum(mat, axis=0) + mat = mat / (col_sum[None, :] + sinkhorn_eps) + if STORE_HIST: + tl.store( + hist_ptr + pid * stride_hn + (t + 1) * stride_ht + rows * stride_hi + cols * stride_hj, + mat, + ) + + # Store h_res [N, HC, HC] (row-major: out, in) + tl.store(hres_ptr + pid * stride_hr_n + rows * stride_hr_i + cols * stride_hr_j, mat) + + +@triton.jit +def _mhc_sinkhorn_bwd_kernel( + mix_ptr, + b_ptr, + grad_out_ptr, + grad_logits_ptr, + N: tl.constexpr, + HC: tl.constexpr, + stride_mn: tl.constexpr, + stride_mm: tl.constexpr, + stride_go_n: tl.constexpr, + stride_go_i: tl.constexpr, + stride_go_j: tl.constexpr, + stride_gl_n: tl.constexpr, + stride_gl_i: tl.constexpr, + stride_gl_j: tl.constexpr, + alpha_res_ptr, + sinkhorn_eps: tl.constexpr, + TMAX: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= N: + return + + alpha_res = tl.load(alpha_res_ptr).to(tl.float32) + + rows = tl.arange(0, HC)[:, None] + cols = tl.arange(0, HC)[None, :] + flat = rows * HC + cols + + # Rebuild logits + mix_res = tl.load(mix_ptr + pid * stride_mn + (2 * HC + flat) * stride_mm).to(tl.float32) + b_res = tl.load(b_ptr + (2 * HC + flat)).to(tl.float32) + logits = mix_res * alpha_res + b_res + + # Forward recompute (no lists) and backward with recompute per step. + row_max = tl.max(logits, axis=1) + e = tl.exp(logits - row_max[:, None]) + row_sum0 = tl.sum(e, axis=1) + p = e / row_sum0[:, None] # softmax, row-wise + p_eps = p + sinkhorn_eps + + col_sum0 = tl.sum(p_eps, axis=0) + mat0 = p_eps / (col_sum0[None, :] + sinkhorn_eps) + + # Start backward from grad_out + g = tl.load( + grad_out_ptr + pid * stride_go_n + rows * stride_go_i + cols * stride_go_j, + ).to(tl.float32) + + # Reverse iterations (TMAX-1 .. 
1), recomputing mat_t, rs_t, cs_t + for t in tl.static_range(TMAX - 1, 0, -1): + mat = mat0 + rs_t = row_sum0 + cs_t = col_sum0 + mat_t = mat0 + + for s in tl.static_range(1, TMAX): + rs = tl.sum(mat, axis=1) + mat = mat / (rs[:, None] + sinkhorn_eps) + cs = tl.sum(mat, axis=0) + mat = mat / (cs[None, :] + sinkhorn_eps) + if s == t: + mat_t = mat + rs_t = rs + cs_t = cs + + denom_col = cs_t + sinkhorn_eps # [HC] + dot_col = tl.sum(g * mat_t, axis=0) # [HC] + g_row = (g - dot_col[None, :]) / denom_col[None, :] + + m_row = mat_t * denom_col[None, :] # invert col norm: m_row = m_out * denom + denom_row = rs_t + sinkhorn_eps + dot_row = tl.sum(g_row * m_row, axis=1) + g = (g_row - dot_row[:, None]) / denom_row[:, None] + + # Undo initial col norm (t=0) + denom_col0 = col_sum0 + sinkhorn_eps + dot_col0 = tl.sum(g * mat0, axis=0) + g_p = (g - dot_col0[None, :]) / denom_col0[None, :] + + # Softmax backward on rows: p * (g_p - sum(g_p * p)) + dot_soft = tl.sum(g_p * p, axis=1) + grad_logits = p * (g_p - dot_soft[:, None]) + + tl.store(grad_logits_ptr + pid * stride_gl_n + rows * stride_gl_i + cols * stride_gl_j, grad_logits) + + +@triton.jit +def _mhc_sinkhorn_bwd_hist_kernel( + mix_ptr, + b_ptr, + hist_ptr, + grad_out_ptr, + grad_logits_ptr, + N: tl.constexpr, + HC: tl.constexpr, + stride_mn: tl.constexpr, + stride_mm: tl.constexpr, + stride_hn: tl.constexpr, + stride_ht: tl.constexpr, + stride_hi: tl.constexpr, + stride_hj: tl.constexpr, + stride_go_n: tl.constexpr, + stride_go_i: tl.constexpr, + stride_go_j: tl.constexpr, + stride_gl_n: tl.constexpr, + stride_gl_i: tl.constexpr, + stride_gl_j: tl.constexpr, + alpha_res_ptr, + sinkhorn_eps: tl.constexpr, + TMAX: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= N: + return + + alpha_res = tl.load(alpha_res_ptr).to(tl.float32) + + rows = tl.arange(0, HC)[:, None] + cols = tl.arange(0, HC)[None, :] + flat = rows * HC + cols + + # Rebuild logits + mix_res = tl.load(mix_ptr + pid * stride_mn + (2 * HC + flat) * stride_mm).to(tl.float32) + b_res = tl.load(b_ptr + (2 * HC + flat)).to(tl.float32) + logits = mix_res * alpha_res + b_res + + # Initial row-softmax + row_max = tl.max(logits, axis=1) + e = tl.exp(logits - row_max[:, None]) + row_sum0 = tl.sum(e, axis=1) + p = e / row_sum0[:, None] + p_eps = p + sinkhorn_eps + + col_sum0 = tl.sum(p_eps, axis=0) + mat0 = p_eps / (col_sum0[None, :] + sinkhorn_eps) + + # Start backward from grad_out + g = tl.load( + grad_out_ptr + pid * stride_go_n + rows * stride_go_i + cols * stride_go_j, + ).to(tl.float32) + + # Reverse iterations (TMAX-1 .. 
1) using stored mats + for t in tl.static_range(TMAX - 1, 0, -1): + mat_t = tl.load(hist_ptr + pid * stride_hn + t * stride_ht + rows * stride_hi + cols * stride_hj).to(tl.float32) + mat_prev = tl.load(hist_ptr + pid * stride_hn + (t - 1) * stride_ht + rows * stride_hi + cols * stride_hj).to( + tl.float32 + ) + + row_sum = tl.sum(mat_prev, axis=1) + mat_row = mat_prev / (row_sum[:, None] + sinkhorn_eps) + col_sum = tl.sum(mat_row, axis=0) + denom_col = col_sum + sinkhorn_eps + + dot_col = tl.sum(g * mat_t, axis=0) + g_row = (g - dot_col[None, :]) / denom_col[None, :] + + m_row = mat_t * denom_col[None, :] + denom_row = row_sum + sinkhorn_eps + dot_row = tl.sum(g_row * m_row, axis=1) + g = (g_row - dot_row[:, None]) / denom_row[:, None] + + # Undo initial col norm (t=0) + denom_col0 = col_sum0 + sinkhorn_eps + dot_col0 = tl.sum(g * mat0, axis=0) + g_p = (g - dot_col0[None, :]) / denom_col0[None, :] + + # Softmax backward on rows: p * (g_p - sum(g_p * p)) + dot_soft = tl.sum(g_p * p, axis=1) + grad_logits = p * (g_p - dot_soft[:, None]) + + tl.store(grad_logits_ptr + pid * stride_gl_n + rows * stride_gl_i + cols * stride_gl_j, grad_logits) + + +def mhc_split_sinkhorn_fwd( + mix: torch.Tensor, + b: torch.Tensor, + alpha_pre: torch.Tensor, + alpha_post: torch.Tensor, + alpha_res: torch.Tensor, + *, + tmax: int, + pre_eps: float, + sinkhorn_eps: float, + post_mult: float, + out_hpre: Optional[torch.Tensor] = None, + out_hpost: Optional[torch.Tensor] = None, + out_hres: Optional[torch.Tensor] = None, + out_hist: Optional[torch.Tensor] = None, + return_hist: bool = False, + num_warps: int = 1, +) -> Union[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] +]: + """ + Compute h_pre, h_post, h_res from `mix` (already normalized by RMS scalar). 
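+
+    The columns of `mix` are consumed in a fixed layout: [0, HC) feeds the sigmoid
+    gate h_pre, [HC, 2*HC) feeds h_post, and the trailing HC*HC columns are the
+    residual logits passed through the Sinkhorn iteration (e.g. HC=4 implies
+    M = 16 + 8 = 24).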
+ + mix: [N, M] float32 where M = HC*HC + 2*HC + b: [M] float32 + """ + assert mix.is_contiguous() and b.is_contiguous() + + N, M = mix.shape + assert M == b.numel() + # infer HC from M = HC*HC + 2*HC + # Solve HC^2 + 2HC - M = 0 + HC = int((math.isqrt(4 + 4 * M) - 2) // 2) + assert HC * HC + 2 * HC == M, f"Invalid M for mHC: M={M}" + + if out_hpre is None: + out_hpre = torch.empty((N, HC), device=mix.device, dtype=torch.float32) + if out_hpost is None: + out_hpost = torch.empty((N, HC), device=mix.device, dtype=torch.float32) + if out_hres is None: + out_hres = torch.empty((N, HC, HC), device=mix.device, dtype=torch.float32) + if return_hist: + if out_hist is None: + out_hist = torch.empty((N, tmax, HC, HC), device=mix.device, dtype=torch.float32) + else: + if out_hist is None: + out_hist = torch.empty((1,), device=mix.device, dtype=torch.float32) + + grid = (N,) + + _mhc_split_sinkhorn_fwd_kernel[grid]( + mix, + b, + out_hpre, + out_hpost, + out_hres, + out_hist, + N=N, + HC=HC, + M=M, + stride_mn=mix.stride(0), + stride_mm=mix.stride(1), + stride_hp_n=out_hpre.stride(0), + stride_hp_h=out_hpre.stride(1), + stride_hq_n=out_hpost.stride(0), + stride_hq_h=out_hpost.stride(1), + stride_hr_n=out_hres.stride(0), + stride_hr_i=out_hres.stride(1), + stride_hr_j=out_hres.stride(2), + stride_hn=out_hist.stride(0) if out_hist.ndim > 1 else 0, + stride_ht=out_hist.stride(1) if out_hist.ndim > 1 else 0, + stride_hi=out_hist.stride(2) if out_hist.ndim > 1 else 0, + stride_hj=out_hist.stride(3) if out_hist.ndim > 1 else 0, + alpha_pre_ptr=alpha_pre.contiguous(), + alpha_post_ptr=alpha_post.contiguous(), + alpha_res_ptr=alpha_res.contiguous(), + pre_eps=pre_eps, + sinkhorn_eps=sinkhorn_eps, + post_mult=post_mult, + TMAX=tmax, + STORE_HIST=return_hist, + num_warps=num_warps, + ) + if return_hist: + return out_hpre, out_hpost, out_hres, out_hist + return out_hpre, out_hpost, out_hres + + +def mhc_sinkhorn_bwd( + mix: torch.Tensor, + b: torch.Tensor, + alpha_res: torch.Tensor, + grad_hres: torch.Tensor, + *, + tmax: int, + sinkhorn_eps: float, + hist: Optional[torch.Tensor] = None, + out_grad_logits: Optional[torch.Tensor] = None, + num_warps: int = 1, +) -> torch.Tensor: + """ + Backward for Sinkhorn: returns grad_logits (same shape as h_res). 
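+
+    When `hist` (the per-iteration matrices recorded by the forward pass with
+    return_hist=True) is provided, the stored-history kernel replays the saved
+    iterates directly; otherwise each reverse step recomputes the forward chain
+    from the logits, trading extra compute for lower memory.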
+ + mix: [N, M] float32 + b: [M] float32 + grad_hres: [N, HC, HC] float32 + """ + assert mix.is_contiguous() and b.is_contiguous() and grad_hres.is_contiguous() + + N, M = mix.shape + HC = grad_hres.shape[1] + assert grad_hres.shape == (N, HC, HC) + assert M == HC * HC + 2 * HC + + if out_grad_logits is None: + out_grad_logits = torch.empty((N, HC, HC), device=mix.device, dtype=torch.float32) + + grid = (N,) + + alpha_res_c = alpha_res.contiguous() + + if hist is not None: + assert hist.is_contiguous() + assert hist.shape == (N, tmax, HC, HC) + _mhc_sinkhorn_bwd_hist_kernel[grid]( + mix, + b, + hist, + grad_hres, + out_grad_logits, + N=N, + HC=HC, + stride_mn=mix.stride(0), + stride_mm=mix.stride(1), + stride_hn=hist.stride(0), + stride_ht=hist.stride(1), + stride_hi=hist.stride(2), + stride_hj=hist.stride(3), + stride_go_n=grad_hres.stride(0), + stride_go_i=grad_hres.stride(1), + stride_go_j=grad_hres.stride(2), + stride_gl_n=out_grad_logits.stride(0), + stride_gl_i=out_grad_logits.stride(1), + stride_gl_j=out_grad_logits.stride(2), + alpha_res_ptr=alpha_res_c, + sinkhorn_eps=sinkhorn_eps, + TMAX=tmax, + num_warps=num_warps, + ) + else: + _mhc_sinkhorn_bwd_kernel[grid]( + mix, + b, + grad_hres, + out_grad_logits, + N=N, + HC=HC, + stride_mn=mix.stride(0), + stride_mm=mix.stride(1), + stride_go_n=grad_hres.stride(0), + stride_go_i=grad_hres.stride(1), + stride_go_j=grad_hres.stride(2), + stride_gl_n=out_grad_logits.stride(0), + stride_gl_i=out_grad_logits.stride(1), + stride_gl_j=out_grad_logits.stride(2), + alpha_res_ptr=alpha_res_c, + sinkhorn_eps=sinkhorn_eps, + TMAX=tmax, + num_warps=num_warps, + ) + return out_grad_logits + + +# ------------------------------------------------------------------------------------------------- +# Apply kernels: mhc_pre and mhc_post_res (forward + backward) +# ------------------------------------------------------------------------------------------------- + + +@triton.jit +def _mhc_pre_fwd_kernel( + x_ptr, + hpre_ptr, + out_ptr, + N: tl.constexpr, + HC: tl.constexpr, + C: tl.constexpr, + stride_xn: tl.constexpr, + stride_xh: tl.constexpr, + stride_xc: tl.constexpr, + stride_hn: tl.constexpr, + stride_hh: tl.constexpr, + stride_on: tl.constexpr, + stride_oc: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_C: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_c = tl.program_id(1) + + n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_offs = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + + acc = tl.zeros((BLOCK_N, BLOCK_C), tl.float32) + for s in tl.static_range(0, HC): + h_s = tl.load( + hpre_ptr + n_offs * stride_hn + s * stride_hh, + mask=(n_offs < N), + other=0.0, + ).to(tl.float32) + xs = tl.load( + x_ptr + n_offs[:, None] * stride_xn + s * stride_xh + c_offs[None, :] * stride_xc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + other=0.0, + ).to(tl.float32) + acc += xs * h_s[:, None] + + tl.store( + out_ptr + n_offs[:, None] * stride_on + c_offs[None, :] * stride_oc, + acc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + ) + + +@triton.jit +def _mhc_pre_bwd_kernel( + x_ptr, + hpre_ptr, + grad_out_ptr, + grad_x_ptr, + grad_h_ptr, + N: tl.constexpr, + HC: tl.constexpr, + C: tl.constexpr, + stride_xn: tl.constexpr, + stride_xh: tl.constexpr, + stride_xc: tl.constexpr, + stride_hn: tl.constexpr, + stride_hh: tl.constexpr, + stride_gon: tl.constexpr, + stride_goc: tl.constexpr, + stride_gxn: tl.constexpr, + stride_gxh: tl.constexpr, + stride_gxc: tl.constexpr, + stride_ghn: tl.constexpr, + stride_ghh: tl.constexpr, + BLOCK_N: tl.constexpr, + 
BLOCK_C: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_c = tl.program_id(1) + + n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_offs = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + + go = tl.load( + grad_out_ptr + n_offs[:, None] * stride_gon + c_offs[None, :] * stride_goc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + other=0.0, + ).to(tl.float32) + + # grad_x = grad_out * hpre + for s in tl.static_range(0, HC): + h_s = tl.load( + hpre_ptr + n_offs * stride_hn + s * stride_hh, + mask=(n_offs < N), + other=0.0, + ).to(tl.float32) + gx = go * h_s[:, None] + tl.store( + grad_x_ptr + n_offs[:, None] * stride_gxn + s * stride_gxh + c_offs[None, :] * stride_gxc, + gx, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + ) + + # grad_hpre: dot(go, x_s) over C -> atomic add + xs = tl.load( + x_ptr + n_offs[:, None] * stride_xn + s * stride_xh + c_offs[None, :] * stride_xc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + other=0.0, + ).to(tl.float32) + part = tl.sum(go * xs, axis=1) + tl.atomic_add( + grad_h_ptr + n_offs * stride_ghn + s * stride_ghh, + part, + mask=n_offs < N, + ) + + +def mhc_pre_fwd( + x: torch.Tensor, + h_pre: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + block_n: int = 32, + block_c: int = 128, + num_warps: int = 4, +) -> torch.Tensor: + assert x.is_contiguous() and h_pre.is_contiguous() + N, HC, C = x.shape + assert h_pre.shape == (N, HC) + + if out is None: + out = torch.empty((N, C), device=x.device, dtype=torch.float32) + + grid = (triton.cdiv(N, block_n), triton.cdiv(C, block_c)) + _mhc_pre_fwd_kernel[grid]( + x, + h_pre, + out, + N=N, + HC=HC, + C=C, + stride_xn=x.stride(0), + stride_xh=x.stride(1), + stride_xc=x.stride(2), + stride_hn=h_pre.stride(0), + stride_hh=h_pre.stride(1), + stride_on=out.stride(0), + stride_oc=out.stride(1), + BLOCK_N=block_n, + BLOCK_C=block_c, + num_warps=num_warps, + ) + return out + + +def mhc_pre_bwd( + x: torch.Tensor, + h_pre: torch.Tensor, + grad_out: torch.Tensor, + *, + out_grad_x: Optional[torch.Tensor] = None, + out_grad_h: Optional[torch.Tensor] = None, + block_n: int = 32, + block_c: int = 128, + num_warps: int = 4, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.is_contiguous() and h_pre.is_contiguous() and grad_out.is_contiguous() + N, HC, C = x.shape + assert grad_out.shape == (N, C) + + if out_grad_x is None: + out_grad_x = torch.empty_like(x, dtype=torch.float32) + if out_grad_h is None: + out_grad_h = torch.zeros((N, HC), device=x.device, dtype=torch.float32) + + grid = (triton.cdiv(N, block_n), triton.cdiv(C, block_c)) + _mhc_pre_bwd_kernel[grid]( + x, + h_pre, + grad_out, + out_grad_x, + out_grad_h, + N=N, + HC=HC, + C=C, + stride_xn=x.stride(0), + stride_xh=x.stride(1), + stride_xc=x.stride(2), + stride_hn=h_pre.stride(0), + stride_hh=h_pre.stride(1), + stride_gon=grad_out.stride(0), + stride_goc=grad_out.stride(1), + stride_gxn=out_grad_x.stride(0), + stride_gxh=out_grad_x.stride(1), + stride_gxc=out_grad_x.stride(2), + stride_ghn=out_grad_h.stride(0), + stride_ghh=out_grad_h.stride(1), + BLOCK_N=block_n, + BLOCK_C=block_c, + num_warps=num_warps, + ) + return out_grad_x, out_grad_h + + +@triton.jit +def _mhc_post_res_fwd_kernel( + x_ptr, + f_ptr, + hpost_ptr, + hres_ptr, + out_ptr, + N: tl.constexpr, + HC: tl.constexpr, + C: tl.constexpr, + stride_xn: tl.constexpr, + stride_xh: tl.constexpr, + stride_xc: tl.constexpr, + stride_fn: tl.constexpr, + stride_fc: tl.constexpr, + stride_hpn: tl.constexpr, + stride_hph: tl.constexpr, + stride_hrn: tl.constexpr, + stride_hri: 
tl.constexpr, + stride_hrj: tl.constexpr, + stride_on: tl.constexpr, + stride_oh: tl.constexpr, + stride_oc: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_C: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_c = tl.program_id(1) + + n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_offs = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + + f = tl.load( + f_ptr + n_offs[:, None] * stride_fn + c_offs[None, :] * stride_fc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + other=0.0, + ).to(tl.float32) + + o2 = tl.arange(0, HC)[:, None] # [HC,1] + hpost = tl.load( + hpost_ptr + n_offs[None, :] * stride_hpn + o2 * stride_hph, + mask=(n_offs[None, :] < N), + other=0.0, + ).to(tl.float32) # [HC, BN] + + acc = f[None, :, :] * hpost[:, :, None] # [HC, BN, BC] + + # residual mixing: sum_i hres[o,i] * x_i + for i in tl.static_range(0, HC): + xs = tl.load( + x_ptr + n_offs[:, None] * stride_xn + i * stride_xh + c_offs[None, :] * stride_xc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + other=0.0, + ).to(tl.float32) # [BN, BC] + w = tl.load( + hres_ptr + n_offs[None, :] * stride_hrn + o2 * stride_hri + i * stride_hrj, + mask=(n_offs[None, :] < N), + other=0.0, + ).to(tl.float32) # [HC, BN] + acc += xs[None, :, :] * w[:, :, None] + + o3 = tl.arange(0, HC)[:, None, None] + n3 = n_offs[None, :, None] + c3 = c_offs[None, None, :] + tl.store( + out_ptr + n3 * stride_on + o3 * stride_oh + c3 * stride_oc, + acc, + mask=(n3 < N) & (c3 < C), + ) + + +@triton.jit +def _mhc_post_res_bwd_kernel( + x_ptr, + f_ptr, + hpost_ptr, + hres_ptr, + grad_out_ptr, + grad_x_ptr, + grad_f_ptr, + grad_hpost_ptr, + grad_hres_ptr, + N: tl.constexpr, + HC: tl.constexpr, + C: tl.constexpr, + stride_xn: tl.constexpr, + stride_xh: tl.constexpr, + stride_xc: tl.constexpr, + stride_fn: tl.constexpr, + stride_fc: tl.constexpr, + stride_hpn: tl.constexpr, + stride_hph: tl.constexpr, + stride_hrn: tl.constexpr, + stride_hri: tl.constexpr, + stride_hrj: tl.constexpr, + stride_gon: tl.constexpr, + stride_goh: tl.constexpr, + stride_goc: tl.constexpr, + stride_gxn: tl.constexpr, + stride_gxh: tl.constexpr, + stride_gxc: tl.constexpr, + stride_gfn: tl.constexpr, + stride_gfc: tl.constexpr, + stride_ghpn: tl.constexpr, + stride_ghph: tl.constexpr, + stride_ghrn: tl.constexpr, + stride_ghri: tl.constexpr, + stride_ghrj: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_C: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_c = tl.program_id(1) + + n_offs = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_offs = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + + f = tl.load( + f_ptr + n_offs[:, None] * stride_fn + c_offs[None, :] * stride_fc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + other=0.0, + ).to(tl.float32) + + o2 = tl.arange(0, HC)[:, None] # [HC,1] + hpost = tl.load( + hpost_ptr + n_offs[None, :] * stride_hpn + o2 * stride_hph, + mask=(n_offs[None, :] < N), + other=0.0, + ).to(tl.float32) # [HC, BN] + + o3 = tl.arange(0, HC)[:, None, None] + n3 = n_offs[None, :, None] + c3 = c_offs[None, None, :] + go = tl.load( + grad_out_ptr + n3 * stride_gon + o3 * stride_goh + c3 * stride_goc, + mask=(n3 < N) & (c3 < C), + other=0.0, + ).to(tl.float32) # [HC, BN, BC] + + # grad_f: sum_o go[o] * hpost[o] + gf = tl.sum(go * hpost[:, :, None], axis=0) + tl.store( + grad_f_ptr + n_offs[:, None] * stride_gfn + c_offs[None, :] * stride_gfc, + gf, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + ) + + # grad_hpost: dot(go[o], f) over C (atomic over C blocks) + part_hpost = tl.sum(go * f[None, :, :], axis=2) # [HC, BN] + tl.atomic_add( + 
grad_hpost_ptr + n_offs[None, :] * stride_ghpn + o2 * stride_ghph, + part_hpost, + mask=(n_offs[None, :] < N), + ) + + # grad_x: hres^T @ go (in-stream i gets sum_o hres[o,i] * go[o]) + for i in tl.static_range(0, HC): + w = tl.load( + hres_ptr + n_offs[None, :] * stride_hrn + o2 * stride_hri + i * stride_hrj, + mask=(n_offs[None, :] < N), + other=0.0, + ).to(tl.float32) # [HC, BN] + gx = tl.sum(go * w[:, :, None], axis=0) # [BN, BC] + tl.store( + grad_x_ptr + n_offs[:, None] * stride_gxn + i * stride_gxh + c_offs[None, :] * stride_gxc, + gx, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + ) + + # grad_hres[o,i]: dot(go[o], x[i]) over C (atomic) + for i in tl.static_range(0, HC): + xi = tl.load( + x_ptr + n_offs[:, None] * stride_xn + i * stride_xh + c_offs[None, :] * stride_xc, + mask=(n_offs[:, None] < N) & (c_offs[None, :] < C), + other=0.0, + ).to(tl.float32) + part_hres = tl.sum(go * xi[None, :, :], axis=2) # [HC, BN] + tl.atomic_add( + grad_hres_ptr + n_offs[None, :] * stride_ghrn + o2 * stride_ghri + i * stride_ghrj, + part_hres, + mask=(n_offs[None, :] < N), + ) + + +def mhc_post_res_fwd( + x: torch.Tensor, + f_out: torch.Tensor, + h_post: torch.Tensor, + h_res: torch.Tensor, + *, + out: Optional[torch.Tensor] = None, + block_n: Optional[int] = None, + block_c: Optional[int] = None, + num_warps: Optional[int] = None, + num_stages: Optional[int] = None, +) -> torch.Tensor: + assert x.is_contiguous() and f_out.is_contiguous() and h_post.is_contiguous() and h_res.is_contiguous() + + N, HC, C = x.shape + assert f_out.shape == (N, C) + assert h_post.shape == (N, HC) + assert h_res.shape == (N, HC, HC) + + if out is None: + out = torch.empty((N, HC, C), device=x.device, dtype=torch.float32) + + block_n, block_c, num_warps, num_stages = _post_res_meta(C, block_n, block_c, num_warps, num_stages) + + grid = (triton.cdiv(N, block_n), triton.cdiv(C, block_c)) + _mhc_post_res_fwd_kernel[grid]( + x, + f_out, + h_post, + h_res, + out, + N=N, + HC=HC, + C=C, + stride_xn=x.stride(0), + stride_xh=x.stride(1), + stride_xc=x.stride(2), + stride_fn=f_out.stride(0), + stride_fc=f_out.stride(1), + stride_hpn=h_post.stride(0), + stride_hph=h_post.stride(1), + stride_hrn=h_res.stride(0), + stride_hri=h_res.stride(1), + stride_hrj=h_res.stride(2), + stride_on=out.stride(0), + stride_oh=out.stride(1), + stride_oc=out.stride(2), + BLOCK_N=block_n, + BLOCK_C=block_c, + num_warps=num_warps, + num_stages=num_stages, + ) + return out + + +def mhc_post_res_bwd( + x: torch.Tensor, + f_out: torch.Tensor, + h_post: torch.Tensor, + h_res: torch.Tensor, + grad_out: torch.Tensor, + *, + out_grad_x: Optional[torch.Tensor] = None, + out_grad_f: Optional[torch.Tensor] = None, + out_grad_hpost: Optional[torch.Tensor] = None, + out_grad_hres: Optional[torch.Tensor] = None, + block_n: Optional[int] = None, + block_c: Optional[int] = None, + num_warps: Optional[int] = None, + num_stages: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + assert ( + x.is_contiguous() + and f_out.is_contiguous() + and h_post.is_contiguous() + and h_res.is_contiguous() + and grad_out.is_contiguous() + ) + + N, HC, C = x.shape + assert grad_out.shape == (N, HC, C) + + if out_grad_x is None: + out_grad_x = torch.empty_like(x, dtype=torch.float32) + if out_grad_f is None: + out_grad_f = torch.empty_like(f_out, dtype=torch.float32) + if out_grad_hpost is None: + out_grad_hpost = torch.zeros((N, HC), device=x.device, dtype=torch.float32) + if out_grad_hres is None: + out_grad_hres = torch.zeros((N, 
HC, HC), device=x.device, dtype=torch.float32) + + block_n, block_c, num_warps, num_stages = _post_res_meta(C, block_n, block_c, num_warps, num_stages) + + grid = (triton.cdiv(N, block_n), triton.cdiv(C, block_c)) + _mhc_post_res_bwd_kernel[grid]( + x, + f_out, + h_post, + h_res, + grad_out, + out_grad_x, + out_grad_f, + out_grad_hpost, + out_grad_hres, + N=N, + HC=HC, + C=C, + stride_xn=x.stride(0), + stride_xh=x.stride(1), + stride_xc=x.stride(2), + stride_fn=f_out.stride(0), + stride_fc=f_out.stride(1), + stride_hpn=h_post.stride(0), + stride_hph=h_post.stride(1), + stride_hrn=h_res.stride(0), + stride_hri=h_res.stride(1), + stride_hrj=h_res.stride(2), + stride_gon=grad_out.stride(0), + stride_goh=grad_out.stride(1), + stride_goc=grad_out.stride(2), + stride_gxn=out_grad_x.stride(0), + stride_gxh=out_grad_x.stride(1), + stride_gxc=out_grad_x.stride(2), + stride_gfn=out_grad_f.stride(0), + stride_gfc=out_grad_f.stride(1), + stride_ghpn=out_grad_hpost.stride(0), + stride_ghph=out_grad_hpost.stride(1), + stride_ghrn=out_grad_hres.stride(0), + stride_ghri=out_grad_hres.stride(1), + stride_ghrj=out_grad_hres.stride(2), + BLOCK_N=block_n, + BLOCK_C=block_c, + num_warps=num_warps, + num_stages=num_stages, + ) + return out_grad_x, out_grad_f, out_grad_hpost, out_grad_hres + + +def _flatten_tokens(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Size]: + """ + Flattens leading dimensions so x becomes [N, HC, C]. + Returns (x_flat, x_shape) where x_shape is the original shape. + """ + assert x.dim() >= 3, "x must be [..., HC, C]" + return x.contiguous().view(-1, x.shape[-2], x.shape[-1]), x.shape + + +class LigerMHCCoeffsFunction(torch.autograd.Function): + """ + Autograd function for mHC coefficient computation. + + Memory/Compute Trade-off: + When gradients are needed, Sinkhorn iteration history (hist) is saved + during forward to avoid recomputation in backward. This increases + memory usage by O(N * tmax * HC^2) but reduces backward compute. 
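+
+    Example (illustrative sketch; all sizes and hyperparameters below are
+    arbitrary, and a Triton-capable GPU is assumed):
+        >>> HC, C = 4, 512
+        >>> M = HC * HC + 2 * HC
+        >>> x = torch.randn(2, 8, HC, C, device="cuda", dtype=torch.bfloat16)
+        >>> phi = torch.randn(HC * C, M, device="cuda", dtype=torch.bfloat16)
+        >>> b = torch.zeros(M, device="cuda")
+        >>> one = torch.ones((), device="cuda")
+        >>> h_pre, h_post, h_res = LigerMHCCoeffsFunction.apply(
+        ...     x, phi, b, one, one, one, False, 8, 1e-6, 1e-3, 1e-6, 2.0
+        ... )
+        >>> h_res.shape
+        torch.Size([2, 8, 4, 4])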
+ """ + + @staticmethod + @ensure_contiguous + def forward( # type: ignore[override] + ctx: Any, + x: torch.Tensor, # [..., HC, C] bf16/fp16 (or fp32 if allow_fp32) + phi: torch.Tensor, # [HC*C, M] + b: torch.Tensor, # [M] + alpha_pre: torch.Tensor, # scalar + alpha_post: torch.Tensor, # scalar + alpha_res: torch.Tensor, # scalar + allow_fp32: bool, + tmax: int, + rms_eps: float, + pre_eps: float, + sinkhorn_eps: float, + post_mult: float, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if allow_fp32: + assert x.dtype in ( + torch.bfloat16, + torch.float16, + torch.float32, + ), "x should be BF16/FP16/FP32 when allow_fp32=True" + else: + assert x.dtype in (torch.bfloat16, torch.float16), "x should be BF16/FP16 (set allow_fp32=True for FP32)" + # Store original shape for restoring at the end + x_shape = x.shape + x_flat, _ = _flatten_tokens(x) + N, HC, C = x_flat.shape + K = HC * C + x_mat = x_flat.view(-1, K) + + assert phi.dim() == 2 and phi.shape[0] == K, f"phi must be [HC*C, M], got {tuple(phi.shape)}" + M = int(phi.shape[1]) + assert b.shape == (M,), f"b must be [M], got {tuple(b.shape)}" + + # (1) fused coeff matmul + norm + mix, invr = mhc_mm_norm_fwd(x_mat, phi, eps=float(rms_eps)) + + # (2) split + sigmoid + sinkhorn + need_hist = any(ctx.needs_input_grad) + if need_hist: + h_pre, h_post, h_res, hist = mhc_split_sinkhorn_fwd( + mix, + b, + alpha_pre, + alpha_post, + alpha_res, + tmax=int(tmax), + pre_eps=float(pre_eps), + sinkhorn_eps=float(sinkhorn_eps), + post_mult=float(post_mult), + return_hist=True, + ) + else: + h_pre, h_post, h_res = mhc_split_sinkhorn_fwd( + mix, + b, + alpha_pre, + alpha_post, + alpha_res, + tmax=int(tmax), + pre_eps=float(pre_eps), + sinkhorn_eps=float(sinkhorn_eps), + post_mult=float(post_mult), + ) + hist = None + + # Save for backward + if hist is not None: + ctx.save_for_backward(x_mat, phi, b, mix, invr, alpha_pre, alpha_post, alpha_res, hist) + else: + ctx.save_for_backward(x_mat, phi, b, mix, invr, alpha_pre, alpha_post, alpha_res) + ctx.meta = ( + x_shape, + HC, + C, + int(tmax), + float(sinkhorn_eps), + float(post_mult), + hist is not None, + ) + + # Reshape to original leading dims + outer = x_shape[:-2] + return ( + h_pre.view(*outer, HC), + h_post.view(*outer, HC), + h_res.view(*outer, HC, HC), + ) + + @staticmethod + @ensure_contiguous + def backward( + ctx: Any, + grad_h_pre: torch.Tensor | None, + grad_h_post: torch.Tensor | None, + grad_h_res: torch.Tensor | None, + ): + saved = ctx.saved_tensors + x_shape, HC, C, tmax, sinkhorn_eps, post_mult, has_hist = ctx.meta + if has_hist: + x_mat, phi, b, mix, invr, alpha_pre, alpha_post, alpha_res, hist = saved + else: + x_mat, phi, b, mix, invr, alpha_pre, alpha_post, alpha_res = saved + hist = None + N = x_mat.shape[0] + M = mix.shape[1] + assert M == HC * HC + 2 * HC + + need_pre = grad_h_pre is not None + need_post = grad_h_post is not None + need_res = grad_h_res is not None + + # flatten grads (None -> zeros) + if need_pre: + gh_pre = grad_h_pre.view(-1, HC).to(torch.float32) + else: + gh_pre = torch.zeros((N, HC), device=mix.device, dtype=torch.float32) + if need_post: + gh_post = grad_h_post.view(-1, HC).to(torch.float32) + else: + gh_post = torch.zeros((N, HC), device=mix.device, dtype=torch.float32) + if need_res: + gh_res = grad_h_res.view(-1, HC, HC).to(torch.float32) + else: + gh_res = torch.zeros((N, HC, HC), device=mix.device, dtype=torch.float32) + + # --- Sinkhorn backward -> grad logits for residual matrix + if need_res: + grad_res_logits = mhc_sinkhorn_bwd( + mix, + b, 
+ alpha_res, + gh_res, + tmax=tmax, + sinkhorn_eps=sinkhorn_eps, + hist=hist, + ) # [N, HC, HC] fp32 + else: + grad_res_logits = gh_res + + # --- Pre/post derivatives (sigmoid) + mix_pre = mix[:, :HC] + mix_post = mix[:, HC : 2 * HC] + mix_res = mix[:, 2 * HC :] + + b_pre = b[:HC] + b_post = b[HC : 2 * HC] + if need_pre: + pre_logits = mix_pre * alpha_pre + b_pre + pre_sig = torch.sigmoid(pre_logits) + grad_pre_logits = gh_pre * (pre_sig * (1.0 - pre_sig)) # [N,HC] + else: + grad_pre_logits = gh_pre + + if need_post: + post_logits = mix_post * alpha_post + b_post + post_sig = torch.sigmoid(post_logits) + grad_post_logits = gh_post * (post_mult * post_sig * (1.0 - post_sig)) # [N,HC] + else: + grad_post_logits = gh_post + + grad_res_logits_flat = grad_res_logits.reshape(N, HC * HC) + + # --- Grad w.r.t mix + grad_mix = torch.empty_like(mix) + grad_mix[:, :HC] = grad_pre_logits * alpha_pre + grad_mix[:, HC : 2 * HC] = grad_post_logits * alpha_post + grad_mix[:, 2 * HC :] = grad_res_logits_flat * alpha_res + + # --- Grad w.r.t b + grad_b = torch.zeros_like(b, dtype=torch.float32) + if need_pre: + grad_b[:HC] = grad_pre_logits.sum(dim=0) + if need_post: + grad_b[HC : 2 * HC] = grad_post_logits.sum(dim=0) + if need_res: + grad_b[2 * HC :] = grad_res_logits_flat.sum(dim=0) + + # --- Grad w.r.t alphas + if need_pre: + grad_alpha_pre = (grad_pre_logits * mix_pre).sum() + else: + grad_alpha_pre = torch.zeros((), device=mix.device, dtype=torch.float32) + if need_post: + grad_alpha_post = (grad_post_logits * mix_post).sum() + else: + grad_alpha_post = torch.zeros((), device=mix.device, dtype=torch.float32) + if need_res: + grad_alpha_res = (grad_res_logits_flat * mix_res).sum() + else: + grad_alpha_res = torch.zeros((), device=mix.device, dtype=torch.float32) + + # --- Grad w.r.t x and phi via fused mm+norm backward + grad_x_mat, grad_phi = mhc_mm_norm_bwd( + x_mat, + phi, + mix, + invr, + grad_mix, + ) + + # Reshape to original shape + grad_x = grad_x_mat.view(x_shape) + + # Return grads for each forward input + return ( + grad_x, # x + grad_phi, # phi + grad_b, # b + grad_alpha_pre, # alpha_pre + grad_alpha_post, # alpha_post + grad_alpha_res, # alpha_res + None, # allow_fp32 + None, # tmax + None, # rms_eps + None, # pre_eps + None, # sinkhorn_eps + None, # post_mult + ) + + +class LigerMHCPreFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx: Any, x: torch.Tensor, h_pre: torch.Tensor) -> torch.Tensor: + x_shape = x.shape + x_flat, _ = _flatten_tokens(x) + h_pre_flat = h_pre.view(-1, x_flat.shape[1]).to(torch.float32) + out = mhc_pre_fwd(x_flat, h_pre_flat) # [N,C] fp32 + ctx.save_for_backward(x_flat, h_pre_flat) + ctx.x_shape = x_shape + out = out.to(x_flat.dtype) + return out.view(*x_shape[:-2], out.shape[-1]) + + @staticmethod + @ensure_contiguous + def backward(ctx: Any, grad_out: torch.Tensor): + x_flat, h_pre_flat = ctx.saved_tensors + x_shape = ctx.x_shape + N, HC, C = x_flat.shape + go = grad_out.view(-1, C).to(torch.float32) + grad_x, grad_h = mhc_pre_bwd(x_flat, h_pre_flat, go) + grad_x = grad_x.to(x_flat.dtype) + return grad_x.view(*x_shape), grad_h.view(*x_shape[:-1]) + + +class LigerMHCPostResFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward( + ctx: Any, x: torch.Tensor, f_out: torch.Tensor, h_post: torch.Tensor, h_res: torch.Tensor + ) -> torch.Tensor: + x_shape = x.shape + x_flat, _ = _flatten_tokens(x) + N, HC, C = x_flat.shape + f_flat = f_out.view(-1, C) + h_post_flat = h_post.view(-1, 
HC).to(torch.float32) + h_res_flat = h_res.view(-1, HC, HC).to(torch.float32) + out = mhc_post_res_fwd(x_flat, f_flat, h_post_flat, h_res_flat) # [N,HC,C] fp32 + ctx.save_for_backward(x_flat, f_flat, h_post_flat, h_res_flat) + ctx.x_shape = x_shape + out = out.to(x_flat.dtype) + return out.view(*x_shape) + + @staticmethod + @ensure_contiguous + def backward(ctx: Any, grad_out: torch.Tensor): + x_flat, f_flat, h_post_flat, h_res_flat = ctx.saved_tensors + x_shape = ctx.x_shape + N, HC, C = x_flat.shape + go = grad_out.view(-1, HC, C).to(torch.float32) + + grad_x, grad_f, grad_hpost, grad_hres = mhc_post_res_bwd(x_flat, f_flat, h_post_flat, h_res_flat, go) + + outer = x_shape[:-2] + return ( + grad_x.to(x_flat.dtype).view(*x_shape), + grad_f.to(f_flat.dtype).view(*outer, C), + grad_hpost.view(*outer, HC), + grad_hres.view(*outer, HC, HC), + ) diff --git a/src/liger_kernel/ops/multi_token_attention.py b/src/liger_kernel/ops/multi_token_attention.py new file mode 100755 index 0000000000000000000000000000000000000000..a91ebf58a90d5c56b746a384e043cf1031522a85 --- /dev/null +++ b/src/liger_kernel/ops/multi_token_attention.py @@ -0,0 +1,207 @@ +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +from torch.nn.modules.utils import _pair + +from liger_kernel.ops.softmax import _softmax_forward +from liger_kernel.ops.sparsemax import _sparsemax_backward +from liger_kernel.ops.sparsemax import _sparsemax_forward +from liger_kernel.ops.utils import calculate_settings +from liger_kernel.ops.utils import ensure_contiguous + + +@triton.jit +def _mask_fwd_kernel( + scores_ptr, + out_ptr, + stride_b, + stride_m, + stride_n, + L, + mask_val: tl.constexpr, + BLOCK: tl.constexpr, + num_warps: tl.constexpr, +): + row_block = tl.program_id(0) + col_block = tl.program_id(1) + batch_id = tl.program_id(2) + + row_idx = row_block * BLOCK + tl.arange(0, BLOCK) + col_idx = col_block * BLOCK + tl.arange(0, BLOCK) + in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L) + + base = scores_ptr + batch_id * stride_b + offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n + future = col_idx[None, :] > row_idx[:, None] + mask_load = in_bounds & ~future + out = tl.load(base + offs, mask=mask_load, other=mask_val, cache_modifier=".ca") + tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".cs") + + +@triton.jit +def _mask_bwd_kernel( + grad_in_ptr, out_ptr, stride_b, stride_m, stride_n, L, BLOCK: tl.constexpr, num_warps: tl.constexpr +): + row_block = tl.program_id(0) + col_block = tl.program_id(1) + batch_id = tl.program_id(2) + + row_idx = row_block * BLOCK + tl.arange(0, BLOCK) + col_idx = col_block * BLOCK + tl.arange(0, BLOCK) + in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L) + + base = grad_in_ptr + batch_id * stride_b + offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n + grad_vals = tl.load(base + offs, mask=in_bounds, other=0.0, cache_modifier=".ca") + + future = col_idx[None, :] > row_idx[:, None] + zero = tl.zeros(grad_vals.shape, dtype=grad_vals.dtype) + out = tl.where(future, zero, grad_vals) + + tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".wb") + + +def _mask_inf_forward(scores: torch.Tensor) -> torch.Tensor: + *batch, L, _ = scores.shape + N = int(torch.prod(torch.tensor(batch))) if batch else 1 + scores_f = scores.view(N, L, L) + out = torch.empty_like(scores_f) + + sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2) + BLOCK_SIZE, 
num_warps = calculate_settings(L) + grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N) + _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=-1e9, BLOCK=BLOCK_SIZE, num_warps=num_warps) + return out.view(*batch, L, L) + + +def _mask_inf_backward(grad: torch.Tensor) -> torch.Tensor: + *batch, L, _ = grad.shape + N = int(torch.prod(torch.tensor(batch))) if batch else 1 + grad_f = grad.view(N, L, L) + out = torch.empty_like(grad_f) + + sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2) + BLOCK_SIZE, num_warps = calculate_settings(L) + grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N) + _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps) + return out.view(*batch, L, L) + + +def _mask_zero_forward(scores: torch.Tensor) -> torch.Tensor: + *batch, L, _ = scores.shape + N = int(torch.prod(torch.tensor(batch))) if batch else 1 + scores_f = scores.view(N, L, L) + out = torch.empty_like(scores_f) + + sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2) + BLOCK_SIZE, num_warps = calculate_settings(L) + grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N) + _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=0.0, BLOCK=BLOCK_SIZE, num_warps=num_warps) + return out.view(*batch, L, L) + + +def _mask_zero_backward(grad: torch.Tensor) -> torch.Tensor: + *batch, L, _ = grad.shape + N = int(torch.prod(torch.tensor(batch))) if batch else 1 + grad_f = grad.view(N, L, L) + out = torch.empty_like(grad_f) + + sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2) + BLOCK_SIZE, num_warps = calculate_settings(L) + grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N) + _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps) + return out.view(*batch, L, L) + + +class LigerMultiTokenAttentionFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, scores, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, sparse=False): + scores_inf = _mask_inf_forward(scores) + + out_flat_sparse = None + activation_output = None + + ctx.sparse = sparse + + if sparse: + if scores_inf.dtype != torch.float32: + raise RuntimeError("Liger sparse multi-token attention currently only supports fp32 input scores") + probs_sparse, out_flat_sparse = _sparsemax_forward(scores_inf, dim=-1) + activation_output = probs_sparse + ctx.save_for_backward(scores_inf, activation_output, out_flat_sparse, weight, bias) + ctx.out_flat_sparse_saved = True + else: + probs_softmax, _, _, _ = _softmax_forward(scores_inf) + activation_output = probs_softmax + ctx.save_for_backward(scores_inf, activation_output, weight, bias) + ctx.out_flat_sparse_saved = False + + out_conv = F.conv2d( + activation_output, + weight, + bias, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + out = _mask_zero_forward(out_conv) + + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.dim = -1 + + return out + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_out): + if ctx.out_flat_sparse_saved: + scores_inf, activation_output, out_flat_sparse, weight, bias = ctx.saved_tensors + else: + scores_inf, activation_output, weight, bias = ctx.saved_tensors + out_flat_sparse = None + + use_sparsemax = ctx.sparse + dim = ctx.dim + stride, padding, dilation, groups = (ctx.stride, ctx.padding, ctx.dilation, ctx.groups) + + grad_conv = 
_mask_zero_backward(grad_out) + + grad_probs = F.conv_transpose2d( + grad_conv, weight, None, stride=stride, padding=padding, dilation=dilation, groups=groups + ) + + grad_weight = torch.nn.grad.conv2d_weight( + input=activation_output, + weight_size=weight.shape, + grad_output=grad_conv, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + grad_bias = None + if bias is not None: + grad_bias = grad_conv.sum(dim=(0, 2, 3)) + + grad_scores_inf = None + if use_sparsemax: + if not ctx.out_flat_sparse_saved or out_flat_sparse is None: + raise RuntimeError("Internal error: Sparse flag is set but sparse tensor was not saved.") + grad_scores_inf = _sparsemax_backward(grad_probs, out_flat_sparse, dim=dim) + else: + grad_probs_cont = grad_probs + probs_cont = activation_output + dot = (grad_probs_cont * probs_cont).sum(dim=-1, keepdim=True) + grad_scores_inf = probs_cont * (grad_probs_cont - dot) + + grad_scores = _mask_inf_backward(grad_scores_inf) + + return (grad_scores, grad_weight, grad_bias, None, None, None, None, None) diff --git a/src/liger_kernel/ops/poly_norm.py b/src/liger_kernel/ops/poly_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..2198e522d2bc1acc61644bb8882253e4d7650268 --- /dev/null +++ b/src/liger_kernel/ops/poly_norm.py @@ -0,0 +1,384 @@ +import operator + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import calculate_settings +from liger_kernel.ops.utils import compare_version +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count +from liger_kernel.ops.utils import set_large_grf_mode +from liger_kernel.utils import is_npu_available + +if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available(): + try: + from triton.language.extra.libdevice import rsqrt + except ModuleNotFoundError: + from triton.language.extra.cuda.libdevice import rsqrt +else: + from triton.language.math import rsqrt + + +@triton.jit +def _poly_norm_forward_kernel( + Y_ptr, + Y_row_stride, + X_ptr, + X_row_stride, + W_ptr, # weight: [3] for [w0, w1, w2] + B_ptr, # bias: scalar + RSTD_ptr, # cache rstd for backward: shape (n_rows, 3) + RSTD_row_stride, + n_cols, + eps, + BLOCK_SIZE: tl.constexpr, +): + """ + PolyNorm formula: + y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b + where norm(u) = u / sqrt(mean(u²) + ε) + + Reference: + 1. https://github.com/BryceZhuo/PolyCom/ + 2. 
https://arxiv.org/pdf/2411.03884 + + Cache rstd values for backward pass + """ + row_idx = tl.program_id(0).to(tl.int64) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + # Load pointers + Y_ptr += row_idx * Y_row_stride + X_ptr += row_idx * X_row_stride + RSTD_ptr += row_idx * RSTD_row_stride + + # Load input row + X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0) + + # Load weights and bias + w0 = tl.load(W_ptr + 0) + w1 = tl.load(W_ptr + 1) + w2 = tl.load(W_ptr + 2) + b = tl.load(B_ptr) + + # Compute x³, x², x + X_pow3 = X_row * X_row * X_row + X_pow2 = X_row * X_row + X_pow1 = X_row + + # Compute norm(x³): norm(u) = u * rsqrt(mean(u²) + eps) + mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols + rstd_3 = rsqrt(mean_square_3 + eps) + norm_x3 = X_pow3 * rstd_3 + + # Compute norm(x²) + mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols + rstd_2 = rsqrt(mean_square_2 + eps) + norm_x2 = X_pow2 * rstd_2 + + # Compute norm(x) + mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols + rstd_1 = rsqrt(mean_square_1 + eps) + norm_x1 = X_pow1 * rstd_1 + + # Cache rstd values for backward + tl.store(RSTD_ptr + 0, rstd_3) + tl.store(RSTD_ptr + 1, rstd_2) + tl.store(RSTD_ptr + 2, rstd_1) + + # Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b + Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b + + # Store output + tl.store(Y_ptr + col_offsets, Y_row, mask=mask) + + +@triton.jit +def _poly_norm_backward_kernel( + dY_ptr, + dY_row_stride, + dX_ptr, + dX_row_stride, + X_ptr, + X_row_stride, + W_ptr, + RSTD_ptr, + RSTD_row_stride, + dW_ptr, # shape: (n_programs, 3) + dW_row_stride, + dB_ptr, # shape: (n_programs,) + n_rows, + n_cols, + rows_per_program: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + PolyNorm Backward Kernel Gradient: + ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)] + + where: + - D_p = RMS(x^p) = 1/rstd_p + - S_p = sum(grad * x^p) over the row + - d = n_cols + - p ∈ {3, 2, 1} + """ + row_block_id = tl.program_id(0).to(tl.int64) + row_start = row_block_id * rows_per_program + row_end = min((row_block_id + 1) * rows_per_program, n_rows) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + # Initialize accumulators for weight and bias gradients (scalars) + dW0_acc = 0.0 + dW1_acc = 0.0 + dW2_acc = 0.0 + dB_acc = 0.0 + + # Load weights + w0 = tl.load(W_ptr + 0).to(tl.float32) + w1 = tl.load(W_ptr + 1).to(tl.float32) + w2 = tl.load(W_ptr + 2).to(tl.float32) + + for row_idx in range(row_start, row_end): + dy_base = dY_ptr + row_idx * dY_row_stride + x_base = X_ptr + row_idx * X_row_stride + dx_base = dX_ptr + row_idx * dX_row_stride + rstd_base = RSTD_ptr + row_idx * RSTD_row_stride + + dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0).to(tl.float32) + X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0).to(tl.float32) + + # Load cached rstd values + rstd_3 = tl.load(rstd_base + 0).to(tl.float32) + rstd_2 = tl.load(rstd_base + 1).to(tl.float32) + rstd_1 = tl.load(rstd_base + 2).to(tl.float32) + + # Compute powers + X_pow3 = X_row * X_row * X_row + X_pow2 = X_row * X_row + X_pow1 = X_row + + # Accumulate bias gradient: dB = sum(dY) + dB_acc += tl.sum(dY_row, axis=0) + + # Compute gradient w.r.t. 
input using closed-form formula + # For p=3: ∂L/∂x from w0 * norm(x³) + S_3 = tl.sum(dY_row * X_pow3, axis=0) # scalar + grad_x_3 = w0 * ( + 3.0 * X_pow2 * rstd_3 * dY_row + - (3.0 / n_cols) * X_row * X_row * X_row * X_row * X_row * (rstd_3 * rstd_3 * rstd_3) * S_3 + ) + + # For p=2: ∂L/∂x from w1 * norm(x²) + S_2 = tl.sum(dY_row * X_pow2, axis=0) # scalar + grad_x_2 = w1 * ( + 2.0 * X_row * rstd_2 * dY_row - (2.0 / n_cols) * X_row * X_row * X_row * (rstd_2 * rstd_2 * rstd_2) * S_2 + ) + + # For p=1: ∂L/∂x from w2 * norm(x) + S_1 = tl.sum(dY_row * X_pow1, axis=0) # scalar + grad_x_1 = w2 * (1.0 * rstd_1 * dY_row - (1.0 / n_cols) * X_row * (rstd_1 * rstd_1 * rstd_1) * S_1) + + # Accumulate weight gradients using closed-form: dW_p = rstd_p * S_p + dW0_acc += rstd_3 * S_3 + dW1_acc += rstd_2 * S_2 + dW2_acc += rstd_1 * S_1 + + # Total gradient + dX_row = grad_x_3 + grad_x_2 + grad_x_1 + + # Store gradient + tl.store(dx_base + col_offsets, dX_row, mask=mask) + + # Store accumulated gradients (scalars) + tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc) + tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc) + tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc) + tl.store(dB_ptr + row_block_id, dB_acc) + + +def poly_norm_forward(X, W, B, eps=1e-6): + """ + PolyNorm Forward Pass + + Args: + X: input tensor of shape (*, H) where H is hidden dimension + W: weight tensor of shape (3,) for [w0, w1, w2] + B: bias scalar tensor + eps: epsilon for numerical stability + + Returns: + Y: output tensor of same shape as X + X: reshaped input (for backward) + RSTD: cached rstd values (for backward) + BLOCK_SIZE: block size used + num_warps: number of warps used + """ + shape = X.shape + dim = shape[-1] + X = X.view(-1, dim) + n_rows, n_cols = X.shape + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + # RSTD is to cache rstd for each row + Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device) + + # Check constraints + assert W.shape[0] == 3, "Weight tensor must have shape (3,)" + assert B.numel() == 1, "Bias must be a scalar" + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + set_large_grf_mode(kernel_args) + + # Launch kernel + _poly_norm_forward_kernel[(n_rows,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + B, + RSTD, + RSTD.stride(0), + n_cols, + eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_args, + ) + + return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps + + +def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place): + """ + PolyNorm Backward Pass + + Args: + dY: gradient of output + X: input tensor (already reshaped to 2D) + W: weight tensor + RSTD: cached rstd values from forward + BLOCK_SIZE: block size from forward + num_warps: number of warps from forward + in_place: whether to in-place modify dY to store dX (saves memory) + + Returns: + dX: gradient w.r.t. input + dW: gradient w.r.t. weight + dB: gradient w.r.t. 
bias + """ + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + n_rows, n_cols = dY.shape + + # Get number of SMs for parallelization + import math + + sm_count = 1 + if X.device.type == "cuda": + sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count + elif X.device.type == "xpu": + sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count + elif X.device.type == "npu": + sm_count = get_npu_core_count() + + # Allocate or reuse gradients + if in_place is True: + dX = dY + else: + dX = torch.zeros_like(dY) + + _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device) + _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device) + + rows_per_program = math.ceil(n_rows / sm_count) + grid = (sm_count,) + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + set_large_grf_mode(kernel_args) + + # Launch backward kernel + _poly_norm_backward_kernel[grid]( + dY, + dY.stride(0), + dX, + dX.stride(0), + X, + X.stride(0), + W, + RSTD, + RSTD.stride(0), + _dW, + _dW.stride(0), + _dB, + n_rows, + n_cols, + rows_per_program, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_args, + ) + + # Reduce gradients across SMs + dX = dX.view(*shape) + dW = _dW.sum(dim=0).to(W.dtype) + dB = _dB.sum().to(W.dtype) + + return dX, dW, dB + + +class LigerPolyNormFunction(torch.autograd.Function): + """ + PolyNorm Function with forward and backward pass + + PolyNorm formula: + y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b + where norm(u) = u / sqrt(mean(u²) + ε) + + Backward uses closed-form gradient: + ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)] + """ + + @staticmethod + @ensure_contiguous + def forward(ctx, X, W, B, eps=1e-6, in_place=True): + """ + Args: + X: input tensor of shape (B, T, H) or (BxT, H) + W: weight tensor of shape (3,) for [w0, w1, w2] + B: bias scalar + eps: epsilon for numerical stability + in_place: whether to in-place modify grad_output in backward (saves memory) + + Returns: + Y: output tensor of same shape as X + """ + Y, X, RSTD, BLOCK_SIZE, num_warps = poly_norm_forward(X, W, B, eps) + ctx.BLOCK_SIZE = BLOCK_SIZE + ctx.num_warps = num_warps + ctx.in_place = in_place + ctx.save_for_backward(X, W, RSTD) + return Y + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output): + """ + Args: + grad_output: gradient of output + + Returns: + dX, dW, dB: gradients w.r.t. 
X, W, B + """ + X, W, RSTD = ctx.saved_tensors + dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place) + return dX, dW, dB, None, None diff --git a/src/liger_kernel/ops/qwen2vl_mrope.py b/src/liger_kernel/ops/qwen2vl_mrope.py new file mode 100755 index 0000000000000000000000000000000000000000..fbd120f96d4b0a7e81da06f118c0ba375976db31 --- /dev/null +++ b/src/liger_kernel/ops/qwen2vl_mrope.py @@ -0,0 +1,222 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _triton_qwen2vl_mrope( + q_ptr, + k_ptr, + cos, + sin, + sl, + bs: tl.constexpr, + n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + pad_n_qh: tl.constexpr, + pad_n_kh: tl.constexpr, + pad_hd: tl.constexpr, + mrope_section_t: tl.constexpr, + mrope_section_h: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BACKWARD_PASS: tl.constexpr = False, +): + pid = tl.program_id(0) + + # locate start address + q_ptr = q_ptr + pid * (n_qh * hd) + k_ptr = k_ptr + pid * (n_kh * hd) + + # #################################################################### + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position + # m of this program instance + # #################################################################### + + # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which + # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension + # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index + # and pid % sl to get the sequence index. + # 2. We only need the left half of cos and sin matrix because the right half is just + # a clone of the left half. + t_end = mrope_section_t + h_end = t_end + mrope_section_h + + t_cos = cos + pid * hd + h_cos = t_cos + bs * sl * hd + w_cos = h_cos + bs * sl * hd + t_sin = sin + pid * hd + h_sin = t_sin + bs * sl * hd + w_sin = h_sin + bs * sl * hd + + cos_offsets = tl.arange(0, pad_hd // 2) + t_mask = cos_offsets < t_end + h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) + w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2) + t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0) + h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0) + w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0) + t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0) + h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0) + w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0) + cos_row = t_cos_row + h_cos_row + w_cos_row + sin_row = t_sin_row + h_sin_row + w_sin_row + + # #################################################################### + # Load the left and right half of q and k for the current + # program instance (i.e. 
for the current token) separately + # #################################################################### + # left half of the head + first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype) + + # right half of the head + second_half_q_offsets = first_half_q_offsets + (hd // 2) + second_half_k_offsets = first_half_k_offsets + (hd // 2) + second_q_mask = first_q_mask + second_k_mask = first_k_mask + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype) + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype) + + if not BACKWARD_PASS: + # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] + new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + else: + # with some math, we can get: + # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin] + new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + + +def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section): + # transpose it back to the physical shape because Triton looks at the physical storage + # note: q and k are incontiguous before the transformation and will become contiguous after transpose + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = q.shape + n_kv_head = k.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure tensors passed into the kernel are contiguous. 
It will be no-op if they are already contiguous + q = q.contiguous() + k = k.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + + _triton_qwen2vl_mrope[(n_row,)]( + q, + k, + cos, + sin, + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + mrope_section[0], + mrope_section[1], + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=False, + ) + return q.transpose(1, 2), k.transpose(1, 2), cos, sin + + +def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section): + dq = dq.transpose(1, 2) + dk = dk.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = dq.shape + n_kv_head = dk.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure dq and dk are contiguous + dq = dq.contiguous() + dk = dk.contiguous() + + # backward is similar to forward except swapping few ops + _triton_qwen2vl_mrope[(n_row,)]( + dq, + dk, + cos, + sin, + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + mrope_section[0], + mrope_section[1], + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=True, + ) + return dq.transpose(1, 2), dk.transpose(1, 2) + + +class LigerQwen2VLMRopeFunction(torch.autograd.Function): + """ + Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation. + + Please find the corresponding HuggingFace implementation here: + https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py + """ + + @staticmethod + def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """ + q size: (bsz, n_q_head, seq_len, head_dim) + k size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (3, bsz, seq_len, head_dim) + sin size: (3, bsz, seq_len, head_dim) + """ + q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section) + ctx.save_for_backward(cos, sin) + ctx.mrope_section = mrope_section + return q, k + + def backward(ctx, dq, dk): + """ + dq size: (bsz, n_q_head, seq_len, head_dim) + dk size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (3, bsz, seq_len, head_dim) + sin size: (3, bsz, seq_len, head_dim) + """ + cos, sin = ctx.saved_tensors + mrope_section = ctx.mrope_section + dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section) + return dq, dk, None, None, None, None diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..e5cab72ea661351fadb5e4513e47dcfd303289ac --- /dev/null +++ b/src/liger_kernel/ops/rms_norm.py @@ -0,0 +1,654 @@ +""" +This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. +See the original Unsloth repository at https://github.com/unslothai/unsloth. + +The following line +https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30 +is based on code from Unsloth, located at: +https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22 + +Modifications made by Yanning Chen, 2024. 
+""" + +import math +import operator + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import calculate_settings +from liger_kernel.ops.utils import compare_version +from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import get_npu_core_count +from liger_kernel.ops.utils import set_large_grf_mode +from liger_kernel.ops.utils import torch_to_triton_dtype +from liger_kernel.utils import is_npu_available + +if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available(): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import rsqrt + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import rsqrt +else: + from triton.language.math import rsqrt + + +_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1) +_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0) +_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1) + + +@triton.jit +def _rms_norm_forward_kernel( + Y_ptr, + Y_row_stride, + X_ptr, + X_row_stride, + W_ptr, + W_row_stride, + RSTD_ptr, + RSTD_row_stride, + n_cols, + eps, + offset, + casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out + elementwise_affine: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N) + + Reference: + 1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + 2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22 + 3. https://arxiv.org/pdf/1910.07467 + """ + + row_idx = tl.program_id(0).to(tl.int64) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + y_base = Y_ptr + row_idx * Y_row_stride + x_base = X_ptr + row_idx * X_row_stride + rstd_base = RSTD_ptr + row_idx * RSTD_row_stride + + X_row = tl.load(x_base + col_offsets, mask=mask, other=0) + X_row_dtype = X_row.dtype + if elementwise_affine: + W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0) + + # On Llama, only rstd is computed on fp32 + if casting_mode == _CASTING_MODE_LLAMA: + X_row = X_row.to(tl.float32) + + # Gemma computes everything on fp32, and then casts back the output to the original dtype + if casting_mode == _CASTING_MODE_GEMMA: + if elementwise_affine: + W_row = W_row.to(tl.float32) + X_row = X_row.to(tl.float32) + + if casting_mode == _CASTING_MODE_NONE: + eps = eps.to(X_row_dtype) + offset = offset.to(X_row_dtype) + + mean_square = tl.sum(X_row * X_row, axis=0) / n_cols + rstd = rsqrt(mean_square + eps) + + # We can save time by caching rms with minimal memory overhead + # because rms is much smaller compared to X_row, as rms is for each row. + # However, on the computation side, it can save 4 operations (*, sum, /, sqrt). 
+    tl.store(rstd_base, rstd)
+
+    X_row = X_row * rstd
+
+    # On Llama, the multiplication with the weight is done on the original dtype
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(X_row_dtype)
+
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)
+    else:
+        Y_row = X_row
+
+    if casting_mode == _CASTING_MODE_GEMMA:
+        Y_row = Y_row.to(X_row_dtype)
+
+    tl.store(y_base + col_offsets, Y_row, mask=mask)
+
+
+@triton.jit
+def _rms_norm_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    X_dtype: tl.constexpr,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_ptr,
+    dW_row_stride,
+    n_rows,
+    n_cols,
+    offset,
+    rows_per_program,
+    casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. "*" means element-wise multiplication, whereas "dot" means dot product
+    dw = sum(dy * (x / RMS)). summation over BxT dimension
+    """
+
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+        W_row = W_row + offset
+
+    for row_idx in range(row_start, row_end):
+        dy_base = dY_ptr + row_idx * dY_row_stride
+        dx_base = dX_ptr + row_idx * dX_row_stride
+
+        x_base = X_ptr + row_idx * X_row_stride
+        rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
+
+        dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
+        X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
+
+        # Get cached rms
+        rstd_row = tl.load(rstd_base)
+
+        X_row = X_row.to(tl.float32)
+
+        # Different backward graphs for different casting modes
+        if casting_mode == _CASTING_MODE_LLAMA:
+            if elementwise_affine:
+                m = (dY_row * W_row).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
+
+        elif casting_mode == _CASTING_MODE_GEMMA:
+            dY_row = dY_row.to(tl.float32)
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
+        else:
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
+
+        dX_row = rstd_row * m
+
+        dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
+
+        if elementwise_affine:
+            # calculate the gradient of W
+            if casting_mode == _CASTING_MODE_LLAMA:
+                dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += dY_row * (X_row * rstd_row)
+
+        tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
+
+    if elementwise_affine:
+        tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+
+
+@triton.jit
+def _block_rms_norm_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    n_rows,
+    n_cols,
+    eps,
+    offset,
+    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    y_i = (x_i / RMS) * (offset + w_i), RMS = sqrt(sum(x_i^2) / N)
+
+    Reference:
+    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+    3. https://arxiv.org/pdf/1910.07467
+    """
+
+    row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    row_mask = row_idx < n_rows
+    col_mask = col_offsets < n_cols
+
+    X_row = tl.load(
+        X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+        mask=row_mask[:, None] & col_mask[None, :],
+        other=0,
+    )
+    X_row_dtype = X_row.dtype
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+
+    # On Llama, only rstd is computed on fp32
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(tl.float32)
+
+    # Gemma computes everything on fp32, and then casts back the output to the original dtype
+    if casting_mode == _CASTING_MODE_GEMMA:
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
+        X_row = X_row.to(tl.float32)
+
+    if casting_mode == _CASTING_MODE_NONE:
+        eps = eps.to(X_row_dtype)
+        offset = offset.to(X_row_dtype)
+
+    mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
+    rstd = rsqrt(mean_square + eps)
+
+    # We can save time by caching rms with minimal memory overhead
+    # because rms is much smaller compared to X_row, as rms is for each row.
+    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
+    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
+
+    X_row = X_row * rstd[:, None]
+
+    # On Llama, the multiplication with the weight is done on the original dtype
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(X_row_dtype)
+
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)[None, :]
+    else:
+        Y_row = X_row
+
+    if casting_mode == _CASTING_MODE_GEMMA:
+        Y_row = Y_row.to(X_row_dtype)
+
+    tl.store(
+        Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
+        Y_row,
+        mask=row_mask[:, None] & col_mask[None, :],
+    )
+
+
+@triton.jit
+def _block_rms_norm_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    X_dtype: tl.constexpr,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_ptr,
+    dW_row_stride,
+    n_rows,
+    n_cols,
+    offset,
+    casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. "*" means element-wise multiplication, whereas "dot" means dot product
+    dw = sum(dy * (x / RMS)). summation over BxT dimension
+    """
+
+    pid = tl.program_id(0).cast(tl.int64)
+    NUM_SMS = tl.num_programs(0)
+
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    col_mask = col_offsets < n_cols
+
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+        W_row = W_row + offset
+
+    for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
+        row_idx = start + tl.arange(0, BLOCK_ROW)
+        row_mask = row_idx < n_rows
+        dY_row = tl.load(
+            dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
+            mask=row_mask[:, None] & col_mask[None, :],
+            other=0.0,
+        )
+        X_row = tl.load(
+            X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+            mask=row_mask[:, None] & col_mask[None, :],
+            other=0.0,
+        )
+
+        # Get cached rms
+        rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
+
+        X_row = X_row.to(tl.float32)
+
+        # Different backward graphs for different casting modes
+        if casting_mode == _CASTING_MODE_LLAMA:
+            if elementwise_affine:
+                m = (dY_row * W_row[None, :]).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
+
+        elif casting_mode == _CASTING_MODE_GEMMA:
+            dY_row = dY_row.to(tl.float32)
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
+        else:
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
+
+        dX_row = rstd_row[:, None] * m
+
+        dX_row += (rstd_row[:, None]) * (
+            -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
+        )
+
+        if elementwise_affine:
+            if casting_mode == _CASTING_MODE_LLAMA:
+                # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+                dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+
+        tl.store(
+            dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
+            dX_row,
+            mask=row_mask[:, None] & col_mask[None, :],
+        )
+
+    if elementwise_affine:
+        tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+
+
+_str_to_casting_mode = {
+    "llama": _CASTING_MODE_LLAMA.value,
+    "gemma": _CASTING_MODE_GEMMA.value,
+    "none": _CASTING_MODE_NONE.value,
+}
+
+
+def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
+    if not isinstance(casting_mode, int):
+        assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
+        casting_mode = _str_to_casting_mode[casting_mode]
+    else:
+        assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
+
+    shape = X.shape
+    dim = shape[-1]
+    X = X.view(-1, dim)
+    n_rows, n_cols = X.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+    # RSTD caches rstd for each row
+    # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
+    rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
+    RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
+
+    if W is not None:
+        # Check constraints.
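+        # Shape contract at this point: X has been flattened to (n_rows, n_cols),
+        # so W must be a 1-D tensor of length n_cols (the hidden dimension).
+        # A minimal usage sketch (hypothetical values, assuming a CUDA device):
+        #   X = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
+        #   W = torch.ones(4096, device="cuda", dtype=torch.bfloat16)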
+ assert X.shape[1] == W.shape[0], ( + "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]" + ) + elementwise_affine = True + else: + elementwise_affine = False + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + set_large_grf_mode(kernel_args) + if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode: + _rms_norm_forward_kernel[(n_rows,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + W.stride(0) if elementwise_affine else 0, + RSTD, + RSTD.stride(0), + n_cols, + eps, + offset, + casting_mode, + elementwise_affine=elementwise_affine, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_args, # XPU-specific optimization + ) + else: + BLOCK_ROW = 16 + kernel_args["BLOCK_ROW"] = BLOCK_ROW + _block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + W.stride(0) if elementwise_affine else 0, + RSTD, + RSTD.stride(0), + n_rows, + n_cols, + eps, + offset, + casting_mode, + elementwise_affine=elementwise_affine, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_args, # XPU-specific optimization + ) + return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode + + +def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode): + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + n_rows, n_cols = dY.shape + + sm_count = 1 + if X.device.type == "cuda": + sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count + elif X.device.type == "xpu": + sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count + elif X.device.type == "npu": + sm_count = get_npu_core_count() + + if W is not None: + # fp32 for numerical stability especially. + _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device) + elementwise_affine = True + else: + _dW = None + elementwise_affine = False + + if n_cols > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + rows_per_program = math.ceil(n_rows / sm_count) + grid = (sm_count,) + + if in_place is True: + dX = dY + else: + dX = torch.zeros_like(dY) + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + set_large_grf_mode(kernel_args) + + if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode: + _rms_norm_backward_kernel[grid]( + dY, + dY.stride(0), + dX, + dX.stride(0), + X, + X.stride(0), + torch_to_triton_dtype[X.dtype], + W, + W.stride(0) if elementwise_affine else 0, + RSTD, + RSTD.stride(0), + _dW, + _dW.stride(0) if elementwise_affine else 0, + n_rows, + n_cols, + offset, + rows_per_program, + casting_mode, + elementwise_affine=elementwise_affine, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_args, # XPU-specific optimization + ) + else: + BLOCK_ROW = 16 + kernel_args["BLOCK_ROW"] = BLOCK_ROW + _block_rms_norm_backward_kernel[grid]( + dY, + dY.stride(0), + dX, + dX.stride(0), + X, + X.stride(0), + torch_to_triton_dtype[X.dtype], + W, + W.stride(0) if elementwise_affine else 0, + RSTD, + RSTD.stride(0), + _dW, + _dW.stride(0) if elementwise_affine else 0, + n_rows, + n_cols, + offset, + casting_mode, + elementwise_affine=elementwise_affine, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_args, # XPU-specific optimization + ) + dX = dX.view(*shape) + + if elementwise_affine: + dW = _dW.sum(dim=0).to(W.dtype) + else: + dW = None + + return dX, dW + + +class LigerRMSNormFunction(torch.autograd.Function): + """ + Performs RMSNorm (Root Mean Square Normalization), 
which normalizes the input tensor `X` using the
+    weight tensor `W`, with an optional offset and casting mode.
+
+    Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
+    uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
+    `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
+
+    In addition, different models cast their inputs at different places during RMSNorm computation. For
+    example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
+    inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
+    support the following casting modes (they match HuggingFace Transformers' implementations):
+    - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
+    - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
+    - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
+
+    The `in_place` option controls whether dY is modified in place to store dX. It defaults to `True` to save memory. However, in certain cases it can produce incorrect gradients.
+    For example, gemma2 uses two RMSNorms sequentially with a residual connection in between. The residual part needs dY, so it cannot be modified in-place.
+    Therefore, when patching RMSNorm in gemma2, we set `in_place` to `False`.
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
+        """
+        X: (B, T, H) or (BxT, H)
+        W: (H,)
+        """
+        if isinstance(X, torch.distributed.tensor.DTensor):
+            # Input tensor is output of a tensor parallel module and
+            # needs to be gathered to a local tensor to compute
+            # RMS norm on each TP worker.
+            # TODO: support CP.
+            X = X.full_tensor()
+
+        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
+        ctx.offset = offset
+        ctx.casting_mode = casting_mode
+        ctx.in_place = in_place
+        ctx.row_mode = row_mode
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        ctx.num_warps = num_warps
+        ctx.elementwise_affine = W is not None
+        if W is not None:
+            ctx.save_for_backward(X, W, RSTD)
+        else:
+            ctx.save_for_backward(X, RSTD)
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dY):
+        """
+        dY: (B, T, H) or (BxT, H)
+        """
+        if ctx.elementwise_affine:
+            X, W, RSTD = ctx.saved_tensors
+        else:
+            X, RSTD = ctx.saved_tensors
+            W = None
+
+        if isinstance(dY, torch.distributed.tensor.DTensor):
+            # Gradients are the output of a tensor parallel module and
+            # need to be gathered to a local tensor for computing the RMS norm backward.
+            # TODO: support CP.
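+            # dY.full_tensor() materializes the sharded DTensor as a regular local
+            # torch.Tensor (effectively an all-gather), mirroring the handling of X
+            # in forward. The backward then returns dX and dW, plus one None for each
+            # non-tensor forward argument (eps, offset, casting_mode, in_place, row_mode).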
+ dY = dY.full_tensor() + + dX, dW = rms_norm_backward( + dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode + ) + return dX, dW, None, None, None, None, None diff --git a/src/liger_kernel/ops/rope.py b/src/liger_kernel/ops/rope.py new file mode 100755 index 0000000000000000000000000000000000000000..bd8ded7308eeddb7762752f4adc95e68903a77ec --- /dev/null +++ b/src/liger_kernel/ops/rope.py @@ -0,0 +1,239 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _triton_rope( + q_ptr, + q_row_stride, + k_ptr, + k_row_stride, + cos, + cos_row_stride, + sin, + sin_row_stride, + sl, + bs: tl.constexpr, + cos_bs: tl.constexpr, + n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + pad_n_qh: tl.constexpr, + pad_n_kh: tl.constexpr, + pad_hd: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BACKWARD_PASS: tl.constexpr = False, +): + # q size: (bsz, seq_len, num_q_heads, head_dim) + # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1) + # k size: (bsz, seq_len, num_kv_heads, head_dim) + # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1) + + # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + # stride: (seq_len * head_dim, head_dim, 1) + pid = tl.program_id(0).to(tl.int64) + + # locate start address + q_ptr = q_ptr + pid * q_row_stride + k_ptr = k_ptr + pid * k_row_stride + + # #################################################################### + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position + # m of this program instance + # #################################################################### + + # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which + # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension + # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index + # and pid % sl to get the sequence index. + # 2. We only need the left half of cos and sin matrix because the right half is just + # a clone of the left half. + batch_idx = pid // sl + cos_row_idx = pid % sl + cos = cos + tl.where( + cos_bs == 1, + cos_row_idx * cos_row_stride, + batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride, + ) + sin = sin + tl.where( + cos_bs == 1, + cos_row_idx * sin_row_stride, + batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride, + ) + + cos_offsets = tl.arange(0, pad_hd // 2) + cos_mask = cos_offsets < hd // 2 + cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0) + sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0) + + # #################################################################### + # Load the left and right half of q and k for the current + # program instance (i.e. 
for the current token) separately + # #################################################################### + # left half of the head + first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype) + + # right half of the head + second_half_q_offsets = first_half_q_offsets + (hd // 2) + second_half_k_offsets = first_half_k_offsets + (hd // 2) + second_q_mask = first_q_mask + second_k_mask = first_k_mask + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype) + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype) + + if not BACKWARD_PASS: + # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] + new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + else: + # with some math, we can get: + # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin] + new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + + +def rope_forward(q, k, cos, sin): + # transpose it back to the physical shape because Triton looks at the physical storage + # note: q and k are incontiguous before the transformation and will become contiguous after transpose + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = q.shape + n_kv_head = k.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure tensors passed into the kernel are contiguous. 
It will be a no-op if they are already contiguous
+ q = q.contiguous()
+ k = k.contiguous()
+ cos = cos.contiguous()
+ sin = sin.contiguous()
+ cos_batch_size = cos.shape[0]
+
+ _triton_rope[(n_row,)](
+ q,
+ q.stride(1),
+ k,
+ k.stride(1),
+ cos,
+ cos.stride(-2),
+ sin,
+ sin.stride(-2),
+ seq_len,
+ batch_size,
+ cos_batch_size,
+ n_q_head,
+ n_kv_head,
+ head_dim,
+ pad_n_q_head,
+ pad_n_kv_head,
+ pad_hd,
+ BLOCK_SIZE=BLOCK_SIZE,
+ BACKWARD_PASS=False,
+ )
+ return q.transpose(1, 2), k.transpose(1, 2), cos, sin
+
+
+def rope_backward(dq, dk, cos, sin):
+ dq = dq.transpose(1, 2)
+ dk = dk.transpose(1, 2)
+
+ batch_size, seq_len, n_q_head, head_dim = dq.shape
+ cos_batch_size = cos.shape[0]
+ n_kv_head = dk.shape[2]
+ pad_hd = triton.next_power_of_2(head_dim)
+ pad_n_q_head = triton.next_power_of_2(n_q_head)
+ pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+ BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+ n_row = batch_size * seq_len
+
+ # ensure dq and dk are contiguous
+ dq = dq.contiguous()
+ dk = dk.contiguous()
+
+ # backward is similar to forward except for swapping a few ops
+ _triton_rope[(n_row,)](
+ dq,
+ dq.stride(1),
+ dk,
+ dk.stride(1),
+ cos,
+ cos.stride(-2),
+ sin,
+ sin.stride(-2),
+ seq_len,
+ batch_size,
+ cos_batch_size,
+ n_q_head,
+ n_kv_head,
+ head_dim,
+ pad_n_q_head,
+ pad_n_kv_head,
+ pad_hd,
+ BLOCK_SIZE=BLOCK_SIZE,
+ BACKWARD_PASS=True,
+ )
+ return dq.transpose(1, 2), dk.transpose(1, 2)
+
+
+class LigerRopeFunction(torch.autograd.Function):
+ """
+ Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
+ this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly different
+ from the one used in the original RoPE paper.
+
+ Please find the corresponding HuggingFace implementation here:
+ https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184
+
+ For more details about the rotation matrix used here, please refer to:
+ https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2
+ """
+
+ @staticmethod
+ def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """
+ q size: (bsz, n_q_head, seq_len, head_dim)
+ k size: (bsz, n_kv_head, seq_len, head_dim)
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+ """
+ q, k, cos, sin = rope_forward(q, k, cos, sin)
+ ctx.save_for_backward(cos, sin)
+ return q, k
+
+ @staticmethod
+ def backward(ctx, dq, dk):
+ """
+ dq size: (bsz, n_q_head, seq_len, head_dim)
+ dk size: (bsz, n_kv_head, seq_len, head_dim)
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+ """
+
+ cos, sin = ctx.saved_tensors
+ dq, dk = rope_backward(dq, dk, cos, sin)
+ return dq, dk, None, None, None, None
diff --git a/src/liger_kernel/ops/softmax.py b/src/liger_kernel/ops/softmax.py
new file mode 100755
index 0000000000000000000000000000000000000000..15db6cdda36e442007d8c380ef13c3c72293abb1
--- /dev/null
+++ b/src/liger_kernel/ops/softmax.py
@@ -0,0 +1,201 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import ensure_contiguous
+
+
+@triton.jit
+def _softmax_single_block_forward_kernel(
+ Y_ptr,
+ Y_row_stride,
+ X_ptr,
+ X_row_stride,
+ n_cols,
+ BLOCK_SIZE: tl.constexpr,
+):
+ row_id = tl.program_id(0)
+ offs = tl.arange(0, BLOCK_SIZE)
+ mask = offs < n_cols
+
+ x = tl.load(X_ptr + row_id * X_row_stride + offs, mask=mask, other=-float("inf"), cache_modifier=".ca")
+ m = tl.max(x, axis=0)
+ e = tl.exp(x - m)
+ d = tl.sum(e, axis=0)
+ y = e / d
+ tl.store(Y_ptr + row_id * Y_row_stride + offs, y, mask=mask, cache_modifier=".cs")
+
+
+@triton.jit
+def _softmax_multi_block_forward_kernel(
+ Y_ptr,
+ Y_row_stride,
+ X_ptr,
+ X_row_stride,
+ n_cols,
+ BLOCK_SIZE: tl.constexpr,
+):
+ row_id = tl.program_id(0)
+ offs = tl.arange(0, BLOCK_SIZE)
+
+ m = tl.float32(-float("inf"))
+ d = tl.float32(0.0)
+ for start in tl.range(0, n_cols, BLOCK_SIZE):
+ idx = start + offs
+ mask = idx < n_cols
+ xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+ blk_max = tl.max(xblk, axis=0)
+ new_m = tl.maximum(m, blk_max)
+ d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0)
+ m = new_m
+
+ for start in tl.range(0, n_cols, BLOCK_SIZE):
+ idx = start + offs
+ mask = idx < n_cols
+ xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+ yblk = tl.exp(xblk - m) / d
+ tl.store(Y_ptr + row_id * Y_row_stride + idx, yblk, mask=mask, cache_modifier=".cs")
+
+
+@triton.jit
+def _softmax_single_block_backward_kernel(
+ dy_ptr,
+ dy_stride,
+ y_ptr,
+ y_stride,
+ dx_ptr,
+ dx_stride,
+ n_cols,
+ BLOCK_SIZE: tl.constexpr,
+):
+ row_id = tl.program_id(0)
+ offs = tl.arange(0, BLOCK_SIZE)
+ mask = offs < n_cols
+
+ dy = tl.load(dy_ptr + row_id * dy_stride + offs, mask=mask, other=0.0)
+ y = tl.load(y_ptr + row_id * y_stride + offs, mask=mask, other=0.0, cache_modifier=".ca")
+ dot = tl.sum(dy * y, axis=0)
+ dx = y * (dy - dot)
+ tl.store(dx_ptr + row_id * dx_stride + offs, dx, mask=mask, cache_modifier=".wb")
+
+
+@triton.jit
+def _softmax_multi_block_backward_kernel(
+ dy_ptr,
+ dy_stride,
+ y_ptr,
+ y_stride,
+ dx_ptr,
+ dx_stride,
+ n_cols,
+ BLOCK_SIZE: tl.constexpr,
+):
+ row_id = tl.program_id(0)
+ offs = tl.arange(0, BLOCK_SIZE)
+ acc = tl.float32(0.0)
+
+ for start in tl.range(0, n_cols, BLOCK_SIZE):
+ idx = start + offs
+ mask = idx < n_cols
+ dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+ y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+ acc += tl.sum(dy_blk * y_blk, axis=0)
+
+ for start in tl.range(0, n_cols, BLOCK_SIZE):
+ idx = start + offs
+ mask = idx < n_cols
+ dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+ y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+ dx_blk = y_blk * (dy_blk - acc)
+ tl.store(dx_ptr + row_id * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb")
+
+
+def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]:
+ *batch, n_cols = x.shape
+ x2d = x.contiguous().view(-1, n_cols)
+ n_rows = x2d.shape[0]
+
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+ y2d = torch.empty_like(x2d)
+
+ if n_cols <= BLOCK_SIZE:
+ _softmax_single_block_forward_kernel[(n_rows,)](
+ y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+ )
+ multi_block_launch = False
+ else:
+ _softmax_multi_block_forward_kernel[(n_rows,)](
+ y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+ )
+ multi_block_launch = True
+
+ return y2d.view(*batch, n_cols), BLOCK_SIZE, num_warps, multi_block_launch
+
+
+def _softmax_backward(
+ dy: torch.Tensor,
+ y: torch.Tensor,
+ BLOCK_SIZE: int,
+ num_warps: int,
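+ # True when the forward pass launched the multi-block kernels; the backward below takes the matching path.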
multi_block_launch: bool, +) -> torch.Tensor: + *batch, n_cols = dy.shape + dy2d = dy.contiguous().view(-1, n_cols) + y2d = y.contiguous().view(-1, n_cols) + n_rows = dy2d.shape[0] + dx2d = torch.empty_like(dy2d) + + if not multi_block_launch and n_cols <= BLOCK_SIZE: + _softmax_single_block_backward_kernel[(n_rows,)]( + dy2d, + dy2d.stride(0), + y2d, + y2d.stride(0), + dx2d, + dx2d.stride(0), + n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + else: + _softmax_multi_block_backward_kernel[(n_rows,)]( + dy2d, + dy2d.stride(0), + y2d, + y2d.stride(0), + dx2d, + dx2d.stride(0), + n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + + return dx2d.view(*batch, n_cols) + + +class LigerSoftmaxFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, input_: torch.Tensor): + y, BLOCK_SIZE, num_warps, multi_block_launch = _softmax_forward(input_) + ctx.save_for_backward(y) + ctx.BLOCK_SIZE = BLOCK_SIZE + ctx.num_warps = num_warps + ctx.multi_block_launch = multi_block_launch + return y + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output): + (y,) = ctx.saved_tensors + dx = _softmax_backward( + grad_output, + y, + ctx.BLOCK_SIZE, + ctx.num_warps, + ctx.multi_block_launch, + ) + return dx diff --git a/src/liger_kernel/ops/sparsemax.py b/src/liger_kernel/ops/sparsemax.py new file mode 100755 index 0000000000000000000000000000000000000000..065785a2aa0de756a11788b5c3b1f9e2464fd0c4 --- /dev/null +++ b/src/liger_kernel/ops/sparsemax.py @@ -0,0 +1,177 @@ +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import calculate_settings +from liger_kernel.ops.utils import ensure_contiguous + + +@triton.jit +def _sparsemax_forward_kernel( + x_ptr, + x_stride_row, + sorted_x_ptr, + sorted_x_stride_row, + o_ptr, + o_stride_row, + n_cols, + BLOCK_SIZE: tl.constexpr, + num_warps: tl.constexpr, +): + pid_row = tl.program_id(0) + ptr_x_data_row = x_ptr + pid_row * x_stride_row + ptr_sorted_x_data_row = sorted_x_ptr + pid_row * sorted_x_stride_row + ptr_output_row = o_ptr + pid_row * o_stride_row + + offs = tl.arange(0, BLOCK_SIZE) + mask = offs < n_cols + + z_sorted_block = tl.load( + ptr_sorted_x_data_row + offs, + mask=mask, + other=-float("inf"), + cache_modifier=".cg", + ).to(tl.float32) + + z_valid = tl.where(mask, z_sorted_block, 0.0) + cssv = tl.cumsum(z_valid, 0) + + r = (offs + 1).to(tl.float32) + t_vec = (cssv - 1.0) / r + + support = (z_sorted_block > t_vec) & mask + + k_int = tl.sum(support.to(tl.int32), 0) + k_clamped_int = tl.maximum(k_int, 1) + k = k_clamped_int.to(tl.float32) + + s = tl.sum(tl.where(support, z_sorted_block, 0.0), 0) + + tau = (s - 1.0) / k + + x_block = tl.load( + ptr_x_data_row + offs, + mask=mask, + other=0.0, + cache_modifier=".cg", + ).to(tl.float32) + + y = tl.maximum(x_block - tau, 0.0) + + tl.store( + ptr_output_row + offs, + y.to(ptr_output_row.dtype.element_ty), + mask=mask, + cache_modifier=".cs", + ) + + +@triton.jit +def _sparsemax_backward_kernel( + o_ptr, go_ptr, gi_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr, num_warps: tl.constexpr +): + row = tl.program_id(0) + o_row = o_ptr + row * stride + go_row = go_ptr + row * stride + gi_row = gi_ptr + row * stride + + offs = tl.arange(0, BLOCK_SIZE) + + supp_cnt = tl.zeros((), tl.float32) + go_sum = tl.zeros((), tl.float32) + + for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)): + offs_iter = i * BLOCK_SIZE + offs + mask_iter = offs_iter < n_cols + o_val = tl.load(o_row + offs_iter, 
mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32) + go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32) + supp = o_val > 0.0 + go_sum += tl.sum(tl.where(supp, go_val, 0.0)) + supp_cnt += tl.sum(supp.to(tl.float32)) + + for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)): + offs_iter = i * BLOCK_SIZE + offs + mask_iter = offs_iter < n_cols + o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32) + go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32) + supp = o_val > 0.0 + gi_val = tl.where( + supp, + go_val - tl.cast(go_sum / tl.maximum(supp_cnt, 1e-6), gi_row.dtype.element_ty).to(tl.float32), + 0.0, + ) + tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb") + + +def _sparsemax_forward(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]: + if dim < 0: + dim += x.dim() + x_sw = x.transpose(dim, -1).contiguous() + n_cols = x_sw.size(-1) + n_rows = x_sw.numel() // n_cols + x_flat = x_sw.view(n_rows, n_cols) + x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + out_flat = torch.empty_like(x_flat) + grid = (n_rows,) + _sparsemax_forward_kernel[grid]( + x_flat, + x_flat.stride(0), + x_sorted_flat, + x_sorted_flat.stride(0), + out_flat, + out_flat.stride(0), + n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + + y = out_flat.view_as(x_sw).transpose(dim, -1) + return y, out_flat + + +def _sparsemax_backward( + grad_out: torch.Tensor, + out_flat: torch.Tensor, + dim: int, +) -> torch.Tensor: + grad_sw = grad_out.transpose(dim, -1).contiguous() + n_cols = grad_sw.size(-1) + n_rows = grad_sw.numel() // n_cols + go_flat = grad_sw.view(n_rows, n_cols) + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + dx_flat = torch.empty_like(go_flat) + grid = (n_rows,) + _sparsemax_backward_kernel[grid]( + out_flat, + go_flat, + dx_flat, + out_flat.stride(0), + n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + + dx = dx_flat.view_as(grad_sw).transpose(dim, -1) + return dx + + +class LigerSparsemaxFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, x: torch.Tensor, dim: int): + y, out_flat = _sparsemax_forward(x, dim) + ctx.save_for_backward(out_flat) + ctx.dim = dim + return y + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_out: torch.Tensor): + (out_flat,) = ctx.saved_tensors + dx = _sparsemax_backward(grad_out, out_flat, ctx.dim) + return dx, None diff --git a/src/liger_kernel/ops/swiglu.py b/src/liger_kernel/ops/swiglu.py new file mode 100755 index 0000000000000000000000000000000000000000..675033683e733b9d15ecafe1dc8c08e70e641c9d --- /dev/null +++ b/src/liger_kernel/ops/swiglu.py @@ -0,0 +1,151 @@ +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import calculate_settings +from liger_kernel.ops.utils import ensure_contiguous + + +@triton.jit +def silu(x): + return x * tl.sigmoid(x) + + +@triton.jit +def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr): + program_id = tl.program_id(0).to(tl.int64) + + # locate start index + a_ptr += program_id * stride + b_ptr += program_id * stride + c_ptr += program_id * stride + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + # sigmoid requires type float32 + a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32) + 
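# note: b stays in its storage dtype; the float32 silu(a) result is cast back to b's dtype before the multiply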
b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0) + c_row = silu(a_row).cast(b_row.dtype) * b_row + tl.store(c_ptr + col_offsets, c_row, mask=mask) + + +@triton.jit +def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr): + program_id = tl.program_id(0).to(tl.int64) + + # locate start index + dc_ptr += program_id * stride + a_ptr += program_id * stride + b_ptr += program_id * stride + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0) + # sigmoid requires type float32 + a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32) + b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0) + + # recomputation to save memory + sig_a = tl.sigmoid(a_row) + silu_a = a_row * sig_a + db_row = dc_row * silu_a + da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row + + tl.store(a_ptr + col_offsets, da_row, mask=mask) + tl.store(b_ptr + col_offsets, db_row, mask=mask) + + +def swiglu_forward(a, b): + ori_shape = a.shape + + n_cols = ori_shape[-1] + a = a.view(-1, n_cols) + b = b.view(-1, n_cols) + c = torch.empty_like(a) + n_rows = a.shape[0] + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + _swiglu_forward_kernel[(n_rows,)]( + a, + b, + c, + c.stride(-2), + n_cols=n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return a, b, c.view(*ori_shape) + + +def swiglu_backward(a, b, dc): + ori_shape = dc.shape + n_cols = ori_shape[-1] + dc = dc.view(-1, n_cols) + n_rows = dc.shape[0] + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + _swiglu_backward_kernel[(n_rows,)]( + dc, + a, + b, + dc.stride(-2), + n_cols=n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return a.view(*ori_shape), b.view(*ori_shape) + + +class LigerSiLUMulFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, a, b): + if isinstance(a, torch.distributed.tensor.DTensor) or isinstance(b, torch.distributed.tensor.DTensor): + device_mesh, placements = ( + (a.device_mesh, a.placements) + if isinstance(a, torch.distributed.tensor.DTensor) + else (b.device_mesh, b.placements) + ) + + # Assume that full tensors are gathered before and identical across + # the associated process groups. + if not isinstance(a, torch.distributed.tensor.DTensor): + a = torch.distributed.tensor.distribute_tensor(a, device_mesh=device_mesh, placements=placements) + if not isinstance(b, torch.distributed.tensor.DTensor): + b = torch.distributed.tensor.distribute_tensor(b, device_mesh=device_mesh, placements=placements) + a_local, b_local, c_local = swiglu_forward(a.to_local(), b.to_local()) + ctx.save_for_backward(a_local, b_local) + ctx.dtensor_metadata = (device_mesh, placements) + return torch.distributed.tensor.DTensor.from_local(c_local, device_mesh, placements) + else: + a, b, c = swiglu_forward(a, b) + ctx.save_for_backward(a, b) + ctx.dtensor_metadata = None + return c + + @staticmethod + @ensure_contiguous + def backward(ctx, dc): + a, b = ctx.saved_tensors + if ctx.dtensor_metadata is not None: + device_mesh, placements = ctx.dtensor_metadata + + # Assume that full tensors are gathered before and identical across + # the associated process groups. 
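+ # Unwrap to local shards first: swiglu_backward runs Triton kernels on plain torch.Tensors, not DTensors.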
+ dc_local = ( + dc.to_local() + if isinstance(dc, torch.distributed.tensor.DTensor) + else torch.distributed.tensor.distribute_tensor(dc, device_mesh=device_mesh, placements=placements) + ) + a_local, b_local = swiglu_backward(a, b, dc_local) + return ( + torch.distributed.tensor.DTensor.from_local(a_local, device_mesh, placements), + torch.distributed.tensor.DTensor.from_local(b_local, device_mesh, placements), + ) + + a, b = swiglu_backward(a, b, dc) + return a, b diff --git a/src/liger_kernel/ops/tiled_mlp.py b/src/liger_kernel/ops/tiled_mlp.py new file mode 100755 index 0000000000000000000000000000000000000000..2c1943c3aaf7f0e37f51dfef4ab5f8950a328790 --- /dev/null +++ b/src/liger_kernel/ops/tiled_mlp.py @@ -0,0 +1,136 @@ +import math + +from typing import Callable +from typing import List +from typing import Optional + +import torch + +from liger_kernel.ops.utils import ensure_contiguous + + +class LigerTiledMLPFunction(torch.autograd.Function): + """ + Based on DeepSpeed's TiledMLP: + https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838 + + Perform a tiled MLP computation to massively reduce memory usage needed to compute MLP + when using very long sequence lengths. + + This module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration. + And if you're using activation checkpointing it then occurs thrice. + + Args: + fn: the function to call on sharded inputs (e.g., mlp.forward) + mlp_module: the MLP nn.Module object + x: the input to MLP.forward (hidden_states) + shards: how many shards to use + compute_params: a list of weights engaged in the compute + + Returns: + the computed hidden_states + """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, + fn: Callable, + mlp_module: torch.nn.Module, + x: torch.Tensor, + shards: int, + compute_params: Optional[List[torch.nn.Parameter]] = None, + ) -> torch.Tensor: + ctx.fn = fn + ctx.mlp_module = mlp_module + ctx.shards = shards + ctx.save_for_backward(x) + + # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts) + x_shards = list(torch.chunk(x, chunks=shards, dim=-2)) + with torch.no_grad(): + output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards] + output_unsharded = torch.cat(output_shards, dim=-2) + + return output_unsharded + + @staticmethod + @ensure_contiguous + def backward(ctx, *grads) -> tuple: + fn = ctx.fn + (x,) = ctx.saved_tensors + mlp_module = ctx.mlp_module + shards = ctx.shards + + x_requires_grad = x.requires_grad + x = x.detach() + # detach() unsets x.requires_grad, so restore it + x.requires_grad_(x_requires_grad) + + # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts) + hidden_size = x.shape[-1] + x_shape_orig = x.shape + + # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1 + x = x.view(-1, hidden_size) + incoming_grad = grads[0].view(-1, hidden_size) + x_grad = torch.zeros_like(x) + + x_shards = list(torch.chunk(x, chunks=shards, dim=0)) + + for i, x_shard in enumerate(x_shards): + x_shard.requires_grad_(x_requires_grad) + + # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step + shard_step = x_shards[i].shape[0] + shard_offset = i * x_shards[0].shape[0] + + x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard) + incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard) + + with torch.enable_grad(): + output = fn(mlp_module, 
x_shard) + torch.autograd.backward(output, incoming_grad_shard) + + # unflatten + x_grad = x_grad.view(x_shape_orig) + + return (None, None, x_grad, None, None) + + +def apply_tiled_mlp( + fn: Callable, + mlp_module: torch.nn.Module, + x: torch.Tensor, + num_shards: Optional[int] = None, + compute_params: Optional[List[torch.nn.Parameter]] = None, +) -> torch.Tensor: + """ + Apply tiled MLP computation for memory efficiency. + + Args: + fn: the function to call on sharded inputs (e.g., lambda module, x: module(x)) + mlp_module: the MLP nn.Module object + x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size] + num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size) + compute_params: list of parameters for DeepSpeed ZeRO optimization + + Returns: + output tensor with the same shape as input + """ + if num_shards is None: + # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] + hidden_size = x.shape[-1] + seqlen = x.shape[-2] + num_shards = math.ceil(seqlen / hidden_size) + + # Ensure num_shards is at least 1 + num_shards = max(1, num_shards) + + return LigerTiledMLPFunction.apply( + fn, + mlp_module, + x, + num_shards, + compute_params, + ) diff --git a/src/liger_kernel/ops/tvd.py b/src/liger_kernel/ops/tvd.py new file mode 100755 index 0000000000000000000000000000000000000000..154df000539438c4c6e0dde5810a7e76468dffc6 --- /dev/null +++ b/src/liger_kernel/ops/tvd.py @@ -0,0 +1,218 @@ +from typing import Literal +from typing import Optional + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import ensure_contiguous + +MAX_FUSED_SIZE = 65536 // 4 + +REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"] + +_REDUCTION_MODE_NONE = tl.constexpr(0) +_REDUCTION_MODE_SUM = tl.constexpr(1) +_REDUCTION_MODE_MEAN = tl.constexpr(2) +_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3) + +_str_to_reduction_mode = { + "none": _REDUCTION_MODE_NONE.value, + "sum": _REDUCTION_MODE_SUM.value, + "mean": _REDUCTION_MODE_MEAN.value, + "batchmean": _REDUCTION_MODE_BATCHMEAN.value, +} + + +def get_num_warps(BLOCK_SIZE): + num_warps = 4 + if BLOCK_SIZE >= 32768: + num_warps = 32 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 + + return num_warps + + +@triton.jit +def _tv_distance_kernel( + p_ptr, + p_stride, + q_ptr, + q_stride, + loss_ptr, + loss_stride, + grads_ptr, + grads_stride, + label_ptr, + ignore_index: tl.constexpr, + n_cols, + scale, # pre-computed reduction scale for gradients (fused into kernel) + BLOCK_SIZE: tl.constexpr, + HAS_LABEL: tl.constexpr, + reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN, +): + pid = tl.program_id(0).to(tl.int64) + p_ptr += pid * p_stride + q_ptr += pid * q_stride + loss_ptr += pid * loss_stride + grads_ptr += pid * grads_stride + label_ptr += pid + + base_offsets = tl.arange(0, BLOCK_SIZE) + + if HAS_LABEL: + label = tl.load(label_ptr) + if label == ignore_index: + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + base_offsets + mask = offsets < n_cols + tl.store(grads_ptr + offsets, 0.0, mask=mask) + if reduction == _REDUCTION_MODE_NONE: + tl.store(loss_ptr + offsets, 0.0, mask=mask) + return + + loss_sum = 0.0 + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + base_offsets + mask = offsets < n_cols + + p = tl.load(p_ptr + offsets, mask=mask, other=0.0) + q = tl.load(q_ptr + offsets, mask=mask, other=0.0) + + # TVD(P || Q) = 0.5 * |P - Q| + tv_loss = 0.5 * tl.abs(p - q) + + # Fuse reduction 
scaling into gradient computation (eliminates separate Python division)
+ grad_res = tl.where(p > q, 0.5 * scale, -0.5 * scale)
+
+ tl.store(grads_ptr + offsets, grad_res, mask=mask)
+
+ if reduction == _REDUCTION_MODE_NONE:
+ tl.store(loss_ptr + offsets, tv_loss, mask=mask)
+ else:
+ loss_sum += tl.sum(tv_loss, axis=0)
+
+ if reduction != _REDUCTION_MODE_NONE:
+ # Fuse reduction scaling into loss (same scale as gradients; avoids Python division)
+ tl.store(loss_ptr, loss_sum * scale)
+
+
+def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
+ BT, V = p.shape
+
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+ num_warps = get_num_warps(BLOCK_SIZE)
+
+ grid = (BT,)
+
+ reduction = _str_to_reduction_mode[reduction]
+
+ out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
+ output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
+ grads = torch.empty_like(p)
+
+ n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
+
+ # Pre-compute gradient scale factor (fused into kernel to avoid separate division)
+ if reduction == _REDUCTION_MODE_BATCHMEAN.value:
+ scale = 1.0 / n_non_ignore
+ elif reduction == _REDUCTION_MODE_MEAN.value:
+ scale = 1.0 / (n_non_ignore * V)
+ else:
+ scale = 1.0
+
+ _tv_distance_kernel[grid](
+ p,
+ p.stride(0),
+ q,
+ q.stride(0),
+ output_tensor,
+ output_tensor.stride(0),
+ grads,
+ grads.stride(0),
+ shift_labels if has_label else torch.empty(1, device=p.device),
+ ignore_index,
+ V,
+ scale,
+ BLOCK_SIZE=BLOCK_SIZE,
+ HAS_LABEL=has_label,
+ num_warps=num_warps,
+ reduction=reduction,
+ )
+
+ # Loss and gradients are already scaled inside the kernel; no separate division needed
+ if reduction in (_REDUCTION_MODE_BATCHMEAN.value, _REDUCTION_MODE_MEAN.value):
+ return output_tensor.sum(), grads
+ elif reduction == _REDUCTION_MODE_SUM.value:
+ return output_tensor.sum(dim=0), grads
+ else:
+ return output_tensor, grads
+
+
+def tvd_backward_triton(grad_output, grads):
+ # If the TVD loss is the last layer, grad_output is 1.0. Skip the mul then.
+ if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+ return grads
+
+ return grads * grad_output
+
+
+class LigerTVDLossFunction(torch.autograd.Function):
+ """
+ Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
+ """
+
+ @staticmethod
+ @ensure_contiguous
+ def forward(
+ ctx,
+ p: torch.Tensor,
+ q: torch.Tensor,
+ shift_labels: Optional[torch.Tensor] = None,
+ reduction: REDUCTION_LITERAL = "batchmean",
+ ignore_index: int = -100,
+ ) -> torch.Tensor:
+ """A forward pass for the Total Variation Distance Loss.
+
+ Args:
+ ctx: Torch autograd context
+ p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
+ q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
+ shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
+ reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
+ ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
+
+ Returns:
+ torch.Tensor: The computed Total Variation Distance Loss.
+ """
+ has_label = False
+ if shift_labels is not None:
+ assert shift_labels.shape == (p.shape[0],), (
+ f"the shape of shift_labels must be (BT,).
Got: {shift_labels.shape}" + ) + shift_labels = shift_labels.contiguous() + has_label = True + + loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label) + ctx.save_for_backward(grads) + return loss + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + """A backward pass for the Total Variation Distance Loss. + + Args: + ctx: Torch autograd context + grad_output (torch.Tensor): The gradient of the loss with respect to the output. + + Returns: + tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs. + """ + (grads,) = ctx.saved_tensors + grads = tvd_backward_triton(grad_output, grads) + + return grads, None, None, None, None diff --git a/src/liger_kernel/ops/utils.py b/src/liger_kernel/ops/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..41c916324a796ac167e09f1759762bb2d3d10cf8 --- /dev/null +++ b/src/liger_kernel/ops/utils.py @@ -0,0 +1,152 @@ +""" +This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. +See the original Unsloth repository at https://github.com/unslothai/unsloth. + +The following line +https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23 +is based on code from Unsloth, located at: +https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 + +Modifications made by Yanning Chen, 2024. +""" + +import functools +import importlib +import operator + +from typing import Callable + +import torch +import triton +import triton.language as tl + +from packaging.version import Version + +from liger_kernel.utils import infer_device + + +def is_hip() -> bool: + return torch.version.hip is not None + + +def ensure_contiguous(fn): + @functools.wraps(fn) + def wrapper(ctx, *args, **kwargs): + def maybe_to_contiguous(x): + return x.contiguous() if isinstance(x, torch.Tensor) else x + + args = [maybe_to_contiguous(arg) for arg in args] + kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()} + return fn(ctx, *args, **kwargs) + + return wrapper + + +def calculate_settings(n): + # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 + + MAX_FUSED_SIZE = 65536 + BLOCK_SIZE = triton.next_power_of_2(n) + if BLOCK_SIZE > MAX_FUSED_SIZE: + raise RuntimeError( + f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}." 
+ ) + + num_warps = 4 + if BLOCK_SIZE >= 32768: + num_warps = 32 if not is_hip() else 16 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 + return BLOCK_SIZE, num_warps + + +def compare_version(package: str, operator: Callable, target: str): + try: + pkg = importlib.import_module(package) + except ImportError: + return False + pkg_version = Version(pkg.__version__) + return operator(pkg_version, Version(target)) + + +def get_amp_custom_fwd_bwd() -> Callable: + device = infer_device() + if compare_version("torch", operator.ge, "2.4.0"): + return ( + functools.partial(torch.amp.custom_fwd, device_type=device), + functools.partial(torch.amp.custom_bwd, device_type=device), + ) + if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None: + return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd + return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd + + +amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd() + + +torch_to_triton_dtype = { + torch.float32: tl.float32, + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, +} + + +@triton.jit +def element_mul_kernel( + X_ptr, + X_stride, + grad_output_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. + The multiplication is performed in-place on the tensor pointed by X_ptr. + + Parameters: + X_ptr: Pointer to the input tensor. + X_stride (int): The stride of the input tensor. + grad_output_ptr: Pointer to the gradient output value. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. + """ + + # Get the program ID and convert it to int64 to avoid overflow + program_id = tl.program_id(0).to(tl.int64) + + # Locate the start index + X_ptr += program_id * X_stride + + # Load the gradient output value + grad_output = tl.load(grad_output_ptr) + + # Perform the element-wise multiplication + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) + tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) + + +def get_npu_core_count(default: int = 20) -> int: + """Return NPU vector core count. + Fallback to `default` if Triton runtime or NPU device is unavailable. + """ + try: + utils = triton.runtime.driver.active.utils + props = utils.get_device_properties(0) + return int(props.get("num_vectorcore", default)) + except Exception: + return default + + +def set_large_grf_mode(kernel_args: dict): + """Set large GRF mode for XPU devices.""" + # On XPU triton installed along with pytorch-xpu will be called `pytorch-triton-xpu`, + # triton XPU installed from source will be called `triton`. 
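+ # Per the version check below: Triton >= 3.6.0 on XPU takes the GRF size as a numeric string ("256"), while older XPU backends only accept the legacy "large" keyword.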
+ if compare_version("pytorch-triton-xpu", operator.ge, "3.6.0") or compare_version("triton", operator.ge, "3.6.0"): + kernel_args["grf_mode"] = "256" + else: + # API was changed in https://github.com/intel/intel-xpu-backend-for-triton/pull/5430 + kernel_args["grf_mode"] = "large" diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..b90d05c430d085108dc96fa31cf11ee390c4a6e0 --- /dev/null +++ b/src/liger_kernel/transformers/__init__.py @@ -0,0 +1,233 @@ +import importlib + +from typing import TYPE_CHECKING + +# Always-safe imports (independent of 'transformers') +from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401 +from liger_kernel.transformers.dyt import LigerDyT # noqa: F401 +from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm # noqa: F401 +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss # noqa: F401 +from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401 +from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401 +from liger_kernel.transformers.jsd import LigerJSD # noqa: F401 +from liger_kernel.transformers.kl_div import LigerKLDIVLoss # noqa: F401 +from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401 +from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb # noqa: F401 +from liger_kernel.transformers.llama4_rope import liger_llama4_vision_rotary_pos_emb # noqa: F401 +from liger_kernel.transformers.mhc import LigerMHC # noqa: F401 +from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention # noqa: F401 +from liger_kernel.transformers.poly_norm import LigerPolyNorm # noqa: F401 +from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401 +from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401 +from liger_kernel.transformers.softmax import LigerSoftmax # noqa: F401 +from liger_kernel.transformers.sparsemax import LigerSparsemax # noqa: F401 +from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401 +from liger_kernel.transformers.swiglu import LigerExperts # noqa: F401 +from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401 +from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP # noqa: F401 +from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401 +from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP # noqa: F401 +from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP # noqa: F401 +from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401 + +# Static-only imports for IDEs and type checkers +if TYPE_CHECKING: + from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401 + from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401 + from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_exaone4 # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1 # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401 + from liger_kernel.transformers.monkey_patch 
import apply_liger_kernel_to_gemma3 # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3 # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_pixtral # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_5 # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_5_moe # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3 # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm # noqa: F401 + + +# Check if 'transformers' is installed +try: + import transformers # noqa: F401 + + _TRANSFORMERS_AVAILABLE = True +except ImportError: + _TRANSFORMERS_AVAILABLE = False + + +def is_transformers_available() -> bool: + """ + Returns True if the 'transformers' package is available. + Useful for conditional logic in downstream code. 
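+
+ Example:
+ >>> from liger_kernel.transformers import is_transformers_available
+ >>> if is_transformers_available():
+ ... from liger_kernel.transformers import apply_liger_kernel_to_llama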
+ """ + return _TRANSFORMERS_AVAILABLE + + +def __getattr__(name: str): + """ + Handles lazy access to transformer-dependent attributes. + If 'transformers' is not installed, raises a user-friendly ImportError. + """ + if not _TRANSFORMERS_AVAILABLE: + raise ImportError( + f"The attribute '{name}' requires the 'transformers' library, which is not installed.\n" + f"Please install it with `pip install transformers` to use this functionality." + ) + + if name == "AutoLigerKernelForCausalLM": + module = importlib.import_module("liger_kernel.transformers.auto_model") + return getattr(module, name) + + monkey_patch_symbols = { + "_apply_liger_kernel", + "_apply_liger_kernel_to_instance", + "apply_liger_kernel_to_falcon_h1", + "apply_liger_kernel_to_gemma", + "apply_liger_kernel_to_gemma2", + "apply_liger_kernel_to_gemma3", + "apply_liger_kernel_to_gemma3_text", + "apply_liger_kernel_to_glm4", + "apply_liger_kernel_to_glm4v", + "apply_liger_kernel_to_glm4v_moe", + "apply_liger_kernel_to_gpt_oss", + "apply_liger_kernel_to_granite", + "apply_liger_kernel_to_internvl", + "apply_liger_kernel_to_llama", + "apply_liger_kernel_to_llava", + "apply_liger_kernel_to_llama4", + "apply_liger_kernel_to_mistral", + "apply_liger_kernel_to_mixtral", + "apply_liger_kernel_to_mllama", + "apply_liger_kernel_to_olmo2", + "apply_liger_kernel_to_olmo3", + "apply_liger_kernel_to_paligemma", + "apply_liger_kernel_to_phi3", + "apply_liger_kernel_to_pixtral", + "apply_liger_kernel_to_qwen2", + "apply_liger_kernel_to_qwen2_5_vl", + "apply_liger_kernel_to_qwen2_vl", + "apply_liger_kernel_to_qwen3", + "apply_liger_kernel_to_qwen3_moe", + "apply_liger_kernel_to_qwen3_5", + "apply_liger_kernel_to_qwen3_5_moe", + "apply_liger_kernel_to_qwen3_next", + "apply_liger_kernel_to_qwen3_vl", + "apply_liger_kernel_to_qwen3_vl_moe", + "apply_liger_kernel_to_smollm3", + "apply_liger_kernel_to_smolvlm", + "apply_liger_kernel_to_hunyuan_v1_dense", + "apply_liger_kernel_to_hunyuan_v1_moe", + "apply_liger_kernel_to_exaone4", + } + + if name in monkey_patch_symbols: + module = importlib.import_module("liger_kernel.transformers.monkey_patch") + return getattr(module, name) + + raise AttributeError(f"module {__name__} has no attribute {name}") + + +# Shared symbols in all environments +__all__ = [ + "is_transformers_available", + "LigerCrossEntropyLoss", + "LigerDyT", + "LigerFusedLinearCrossEntropyLoss", + "LigerFusedLinearJSD", + "LigerGEGLUMLP", + "LigerJSD", + "LigerLayerNorm", + "LigerFusedAddRMSNorm", + "LigerPolyNorm", + "LigerRMSNorm", + "liger_rotary_pos_emb", + "liger_llama4_text_rotary_pos_emb", + "liger_llama4_vision_rotary_pos_emb", + "LigerBlockSparseTop2MLP", + "LigerPhi3SwiGLUMLP", + "LigerQwen3MoeSwiGLUMLP", + "LigerSwiGLUMLP", + "LigerTiledGEGLUMLP", + "LigerTiledSwiGLUMLP", + "LigerTVDLoss", + "LigerKLDIVLoss", + "LigerMHC", + "LigerMultiTokenAttention", + "LigerSoftmax", + "LigerSparsemax", +] + +# Add transformer-dependent symbols only if available +if _TRANSFORMERS_AVAILABLE: + __all__.extend( + [ + "AutoLigerKernelForCausalLM", + "_apply_liger_kernel", + "_apply_liger_kernel_to_instance", + "apply_liger_kernel_to_falcon_h1", + "apply_liger_kernel_to_gemma", + "apply_liger_kernel_to_gemma2", + "apply_liger_kernel_to_gemma3", + "apply_liger_kernel_to_gemma3_text", + "apply_liger_kernel_to_glm4", + "apply_liger_kernel_to_glm4v", + "apply_liger_kernel_to_glm4v_moe", + "apply_liger_kernel_to_gpt_oss", + "apply_liger_kernel_to_granite", + "apply_liger_kernel_to_internvl", + "apply_liger_kernel_to_llama", + 
"apply_liger_kernel_to_llava", + "apply_liger_kernel_to_llama4", + "apply_liger_kernel_to_mistral", + "apply_liger_kernel_to_mixtral", + "apply_liger_kernel_to_mllama", + "apply_liger_kernel_to_olmo2", + "apply_liger_kernel_to_olmo3", + "apply_liger_kernel_to_paligemma", + "apply_liger_kernel_to_phi3", + "apply_liger_kernel_to_pixtral", + "apply_liger_kernel_to_qwen2", + "apply_liger_kernel_to_qwen2_5_vl", + "apply_liger_kernel_to_qwen2_vl", + "apply_liger_kernel_to_qwen3", + "apply_liger_kernel_to_qwen3_moe", + "apply_liger_kernel_to_qwen3_5", + "apply_liger_kernel_to_qwen3_5_moe", + "apply_liger_kernel_to_qwen3_next", + "apply_liger_kernel_to_qwen3_vl", + "apply_liger_kernel_to_qwen3_vl_moe", + "apply_liger_kernel_to_smollm3", + "apply_liger_kernel_to_smolvlm", + "apply_liger_kernel_to_hunyuan_v1_dense", + "apply_liger_kernel_to_hunyuan_v1_moe", + "apply_liger_kernel_to_exaone4", + ] + ) diff --git a/src/liger_kernel/transformers/auto_model.py b/src/liger_kernel/transformers/auto_model.py new file mode 100755 index 0000000000000000000000000000000000000000..004a9808ab631b90f11a8f41c2a3111eaceac66f --- /dev/null +++ b/src/liger_kernel/transformers/auto_model.py @@ -0,0 +1,59 @@ +import inspect +import logging + +from transformers import AutoConfig +from transformers import AutoModelForCausalLM + +from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN +from liger_kernel.transformers.monkey_patch import _apply_liger_kernel + +logger = logging.getLogger(__name__) + + +def _get_model_config(model_dir, **model_init_kwargs): + config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs) + return config + + +class AutoLigerKernelForCausalLM(AutoModelForCausalLM): + """ + This class is a drop-in replacement for AutoModelForCausalLM that applies the Liger Kernel to the model + if applicable. + """ + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + model_config = _get_model_config(pretrained_model_name_or_path, **kwargs) + + # Determine the model type and apply the Liger Kernel if applicable + # Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function + model_type = model_config.model_type + + _apply_liger_kernel(model_type, **kwargs) + + # Filter out kwargs that were passed to the apply_liger_* function, which will cause + # model initialization errors otherwise + apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type] + apply_fn_signature = inspect.signature(apply_fn) + + applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters} + + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs) + + @classmethod + def from_config(cls, config, **kwargs): + model_type = getattr(config, "model_type", None) + if not model_type: + logger.info("Model type could not be determined from model config. 
No Liger kernels will be applied.")
+ return super().from_config(config, **kwargs)
+
+ _apply_liger_kernel(model_type, **kwargs)
+
+ # Filter out kwargs that were passed to the apply_liger_* function, which will cause
+ # model initialization errors otherwise
+ apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+ apply_fn_signature = inspect.signature(apply_fn)
+ applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
+
+ return super().from_config(config, **applicable_kwargs)
diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py
new file mode 100755
index 0000000000000000000000000000000000000000..ed4310714994e3d75d061a7f87a4c1831df7f95b
--- /dev/null
+++ b/src/liger_kernel/transformers/cross_entropy.py
@@ -0,0 +1,61 @@
+from typing import Optional
+
+import torch
+
+from liger_kernel.ops import LigerCrossEntropyFunction
+from liger_kernel.transformers.functional import CrossEntropyOutput
+
+
+class LigerCrossEntropyLoss(torch.nn.Module):
+ def __init__(
+ self,
+ weight: Optional[torch.FloatTensor] = None,
+ ignore_index: int = -100,
+ lse_square_scale: float = 0.0,
+ label_smoothing: float = 0.0,
+ reduction: str = "mean",
+ softcap: Optional[float] = None,
+ return_z_loss: bool = False,
+ return_token_accuracy: bool = False,
+ return_predicted_tokens: bool = False,
+ ):
+ super().__init__()
+ assert (label_smoothing >= 0) and (label_smoothing <= 1), (
+ f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+ )
+ assert reduction in {
+ "mean",
+ "sum",
+ "none",
+ }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
+ assert softcap is None or softcap > 0, f"softcap must be greater than 0.0 or None. Got: {softcap}"
+ self.weight = weight
+ self.ignore_index = ignore_index
+ self.lse_square_scale = lse_square_scale
+ self.label_smoothing = label_smoothing
+ self.reduction = reduction
+ self.softcap = softcap
+ self.return_z_loss = return_z_loss
+ self.return_token_accuracy = return_token_accuracy
+ self.return_predicted_tokens = return_predicted_tokens
+
+ def forward(self, _input: torch.Tensor, target: torch.Tensor):
+ loss, z_loss, token_accuracy, predicted_tokens = LigerCrossEntropyFunction.apply(
+ _input,
+ target,
+ self.weight,
+ self.ignore_index,
+ self.lse_square_scale,
+ self.label_smoothing,
+ self.reduction,
+ self.softcap,
+ self.return_z_loss,
+ self.return_token_accuracy,
+ self.return_predicted_tokens,
+ )
+ if not self.return_z_loss and not self.return_token_accuracy and not self.return_predicted_tokens:
+ return loss
+
+ return CrossEntropyOutput(
+ loss=loss, z_loss=z_loss, token_accuracy=token_accuracy, predicted_tokens=predicted_tokens
+ )
diff --git a/src/liger_kernel/transformers/dyt.py b/src/liger_kernel/transformers/dyt.py
new file mode 100755
index 0000000000000000000000000000000000000000..8dd0796fc2bc1f6bc7b1e692c7617f0d4e19ea92
--- /dev/null
+++ b/src/liger_kernel/transformers/dyt.py
@@ -0,0 +1,22 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops import LigerDyTFunction
+
+
+class LigerDyT(nn.Module):
+ def __init__(self, hidden_size, beta=True, init_alpha=0.5):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.init_alpha = init_alpha
+ self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
+ self.gamma = nn.Parameter(torch.ones(hidden_size))
+ self.beta = None
+ if beta:
+ self.beta = nn.Parameter(torch.zeros(hidden_size))
+
+ def forward(self, x):
+ return LigerDyTFunction.apply(x, self.alpha,
self.gamma, self.beta) + + def extra_repr(self): + return f"{self.hidden_size}, init_alpha={self.init_alpha}, beta={self.beta}" diff --git a/src/liger_kernel/transformers/experimental/__init__.py b/src/liger_kernel/transformers/experimental/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..a662f76e5e439a460a2b2c249301a1313ec7348a --- /dev/null +++ b/src/liger_kernel/transformers/experimental/__init__.py @@ -0,0 +1,5 @@ +from liger_kernel.transformers.experimental.embedding import LigerEmbedding # noqa: F401 + +__all__ = [ + "LigerEmbedding", +] diff --git a/src/liger_kernel/transformers/experimental/embedding.py b/src/liger_kernel/transformers/experimental/embedding.py new file mode 100755 index 0000000000000000000000000000000000000000..7c230b885eb2a6f8cce69181040598db2cf2f8f3 --- /dev/null +++ b/src/liger_kernel/transformers/experimental/embedding.py @@ -0,0 +1,26 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from liger_kernel.ops import LigerEmbeddingFunction + + +class LigerEmbedding(nn.Module): + def __init__(self, num_embeddings, embedding_dim, padding_idx: Optional[int] = None): + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim)) + + if padding_idx is not None: + with torch.no_grad(): + self.weight[padding_idx].fill_(0) + + def forward(self, indices): + embedded = LigerEmbeddingFunction.apply(self.weight, indices) + if self.padding_idx is not None: + embedded = embedded.clone() + embedded[indices == self.padding_idx] = 0 + return embedded diff --git a/src/liger_kernel/transformers/fsdp.py b/src/liger_kernel/transformers/fsdp.py new file mode 100755 index 0000000000000000000000000000000000000000..d32bdd2603b4cbebd2b5bb913978f460d00cb179 --- /dev/null +++ b/src/liger_kernel/transformers/fsdp.py @@ -0,0 +1,55 @@ +from typing import Any +from typing import Callable + +from torch.distributed.fsdp import FullyShardedDataParallel + + +class _FSDPForwardRedirection: + """ + Modified based on + https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648 + Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and + post-forward can be properly executed around the method call. + This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only + the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving + GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`) + will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of + the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather + its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just + the `lm_head` part of a model, we need this trick too to properly get its params all-gathered. + """ + + def __call__( + self, + wrapper_module: FullyShardedDataParallel, + method: Callable, + *args: Any, + **kwargs: Any, + ): + """Reroutes a method call through the `wrapper_module`'s `forward` method. + Args: + wrapper_module: The module that has `original_module` wrapped. 
method: The method to call on the wrapped `original_module` (retrieved internally via
+ `wrapper_module._fsdp_wrapped_module`) after inputs get redirected through the
+ `wrapper_module`'s `forward` method.
+ *args: The positional arguments to `method`. They will get passed to a patched
+ `forward` method instead.
+ **kwargs: The keyword arguments to `method`. They will get passed to a patched
+ `forward` method instead.
+ """
+ assert isinstance(wrapper_module, FullyShardedDataParallel)
+ original_module = wrapper_module._fsdp_wrapped_module
+ original_forward = original_module.forward
+
+ def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
+ # Unpatch ourselves immediately before calling `method`,
+ # because `method` may itself want to call the real `forward`
+ original_module.forward = original_forward # type: ignore[method-assign]
+ # Call the actual method, e.g. `.training_step(...)`
+ out = method(*_args, **_kwargs)
+ return out
+
+ # Patch the original_module's forward so we can redirect the arguments back to the real method
+ original_module.forward = wrapped_forward # type: ignore[method-assign]
+ wrapper_output = wrapper_module(*args, **kwargs)
+ return wrapper_output
diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py
new file mode 100755
index 0000000000000000000000000000000000000000..9fc083ba4e0fbaba1dfc06b990c540af6669a81e
--- /dev/null
+++ b/src/liger_kernel/transformers/functional.py
@@ -0,0 +1,410 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+from liger_kernel.ops import LigerCrossEntropyFunction
+from liger_kernel.ops import LigerDyTFunction
+from liger_kernel.ops import LigerFusedAddRMSNormFunction
+from liger_kernel.ops import LigerFusedLinearCrossEntropyFunction
+from liger_kernel.ops import LigerFusedLinearJSDFunction
+from liger_kernel.ops import LigerFusedNeighborhoodAttentionFunction
+from liger_kernel.ops import LigerGELUMulFunction
+from liger_kernel.ops import LigerGroupNormFunction
+from liger_kernel.ops import LigerJSDFunction
+from liger_kernel.ops import LigerKLDivLossFunction
+from liger_kernel.ops import LigerLayerNormFunction
+from liger_kernel.ops import LigerMHCCoeffsFunction
+from liger_kernel.ops import LigerMHCPostResFunction
+from liger_kernel.ops import LigerMHCPreFunction
+from liger_kernel.ops import LigerMultiTokenAttentionFunction
+from liger_kernel.ops import LigerPolyNormFunction
+from liger_kernel.ops import LigerQwen2VLMRopeFunction
+from liger_kernel.ops import LigerRMSNormFunction
+from liger_kernel.ops import LigerRopeFunction
+from liger_kernel.ops import LigerSiLUMulFunction
+from liger_kernel.ops import LigerSoftmaxFunction
+from liger_kernel.ops import LigerSparsemaxFunction
+from liger_kernel.ops import LigerTVDLossFunction
+
+
+@dataclass
+class CrossEntropyOutput:
+ loss: torch.Tensor
+ z_loss: Optional[torch.Tensor] = None
+ token_accuracy: Optional[torch.Tensor] = None
+ predicted_tokens: Optional[torch.Tensor] = None
+
+
+# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
+# `size_average` and `reduce` are deprecated placeholders; `weight` is forwarded to the kernel
+def liger_cross_entropy(
+ input,
+ target,
+ weight=None,
+ size_average=None,
+ ignore_index: int = -100,
+ reduce=None,
+ reduction: str = "mean",
+ label_smoothing: float = 0.0,
+ lse_square_scale: float = 0.0,
+ softcap: Optional[float] = None,
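+ # Setting any of the return_* flags below switches the return type from a bare loss tensor to CrossEntropyOutput.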
return_z_loss: bool = False, + return_token_accuracy: bool = False, + return_predicted_tokens: bool = False, +): + loss, z_loss, token_accuracy, predicted_tokens = LigerCrossEntropyFunction.apply( + input, + target, + weight, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, + return_token_accuracy, + return_predicted_tokens, + ) + + if not return_z_loss and not return_token_accuracy and not return_predicted_tokens: + return loss + + return CrossEntropyOutput( + loss=loss, z_loss=z_loss, token_accuracy=token_accuracy, predicted_tokens=predicted_tokens + ) + + +def liger_fused_linear_cross_entropy( + input, + weight, + target, + bias=None, + ce_weight=None, + ignore_index: int = -100, + lse_square_scale: float = 0.0, + label_smoothing: float = 0.0, + reduction: str = "mean", + softcap: Optional[float] = None, + return_z_loss: bool = False, + accum_dtype=None, + use_token_scaling: bool = False, + return_token_accuracy: bool = False, + return_predicted_tokens: bool = False, +): + loss, z_loss, token_accuracy, predicted_tokens = LigerFusedLinearCrossEntropyFunction.apply( + input, + weight, + target, + bias, + ce_weight, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, + accum_dtype, + use_token_scaling, + return_token_accuracy, + return_predicted_tokens, + ) + + if not return_z_loss and not return_token_accuracy and not return_predicted_tokens: + return loss + + return CrossEntropyOutput( + loss=loss, z_loss=z_loss, token_accuracy=token_accuracy, predicted_tokens=predicted_tokens + ) + + +def liger_fused_linear_jsd( + student_input, + student_weight, + teacher_input, + teacher_weight, + shift_labels=None, + jsd_beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, +): + return LigerFusedLinearJSDFunction.apply( + student_input, + student_weight, + teacher_input, + teacher_weight, + shift_labels, + jsd_beta, + ignore_index, + temperature, + ) + + +def liger_geglu(a, b): + return LigerGELUMulFunction.apply(a, b) + + +def liger_group_norm( + X, + affine_scaling_weight, + affine_shifting_bias, + num_channels, + num_groups, + eps, +): + return LigerGroupNormFunction.apply( + X, + affine_scaling_weight, + affine_shifting_bias, + num_channels, + num_groups, + eps, + ) + + +def liger_jsd( + input, + target, + shift_labels=None, + beta: float = 0.5, + ignore_index: int = -100, +): + return LigerJSDFunction.apply( + input, + target, + shift_labels, + beta, + ignore_index, + ) + + +# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div +# `size_average` and `mean` are being deprecated in torch API and are placeholders here +def liger_kl_div( + input, + target, + size_average: bool = True, + reduce: bool = True, + reduction: str = "mean", + log_target: bool = False, + eps: float = 1e-10, +): + # Note: the default reduction in torch is `mean`, but being `batchmean` in Liger + return LigerKLDivLossFunction.apply( + input, + target, + reduction, + log_target, + eps, + ) + + +def liger_sparsemax( + input, + dim: int = -1, +): + return LigerSparsemaxFunction.apply(input, dim) + + +def liger_multi_token_attention( + scores, + weight, + bias=None, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + sparse: bool = False, +): + """ + Functional interface for multi-token attention. 
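+ The score maps are treated as a (B, C_in, L, L) input to a (grouped, optionally dilated) 2D convolution,
+ so each output attention score can mix information from neighboring query/key positions.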
+
+    Args:
+        scores: Input tensor of shape (B, C_in, L, L)
+        weight: Convolution weight tensor of shape (C_out, C_in // groups, K, K)
+        bias: Optional bias tensor of shape (C_out,)
+        stride: Stride for the convolution (default: 1)
+        padding: Padding for the convolution (default: 0)
+        dilation: Dilation factor for the convolution (default: 1)
+        groups: Number of groups for the convolution (default: 1)
+        sparse: Specifies if input tensors are expected to be sparse (default: False)
+    Returns:
+        Output tensor after applying multi-token attention.
+    """
+    return LigerMultiTokenAttentionFunction.apply(scores, weight, bias, stride, padding, dilation, groups, sparse)
+
+
+def liger_fused_neighborhood_attention(
+    query,
+    key,
+    value,
+    kernel_size: int = 7,
+    dilation: int = 1,
+    scale: Optional[float] = None,
+):
+    """
+    Liger fused neighborhood attention.
+
+    paper: https://arxiv.org/pdf/2504.16922
+
+    Args:
+        query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
+        key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
+        value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]
+        kernel_size: Size of the neighborhood window (default: 7)
+        dilation: Dilation factor for the neighborhood (default: 1)
+        scale: Scaling factor for attention scores (default: rsqrt(head_dim))
+
+    Returns:
+        Output tensor of shape [batch_size, num_heads, seq_len, head_dim]
+    """
+    return LigerFusedNeighborhoodAttentionFunction.apply(query, key, value, kernel_size, dilation, scale)
+
+
+def liger_tvd(
+    input,
+    target,
+    shift_labels=None,
+    reduction: str = "mean",
+    ignore_index: int = -100,
+):
+    return LigerTVDLossFunction.apply(
+        input,
+        target,
+        shift_labels,
+        reduction,
+        ignore_index,
+    )
+
+
+def liger_layer_norm(X, W, B, eps):
+    return LigerLayerNormFunction.apply(X, W, B, eps)
+
+
+def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+    return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
+
+
+def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
+    return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
+
+
+def liger_poly_norm(X, W, B, eps=1e-6, in_place=True):
+    return LigerPolyNormFunction.apply(X, W, B, eps, in_place)
+
+
+def liger_fused_add_rms_norm(X, R, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
+    return LigerFusedAddRMSNormFunction.apply(X, R, W, eps, offset, casting_mode, in_place)
+
+
+def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+
+def liger_swiglu(a, b):
+    return LigerSiLUMulFunction.apply(a, b)
+
+
+def liger_softmax(x):
+    return LigerSoftmaxFunction.apply(x)
+
+
+def liger_dyt(x, alpha, gamma, beta):
+    return LigerDyTFunction.apply(x, alpha, gamma, beta)
+
+
+def liger_mhc_coeffs(
+    x,
+    phi,
+    b,
+    alpha_pre,
+    alpha_post,
+    alpha_res,
+    *,
+    allow_fp32: bool = False,
+    tmax: int = 20,
+    rms_eps: float = 1e-6,
+    pre_eps: float = 0.0,
+    sinkhorn_eps: float = 1e-6,
+    post_mult: float = 2.0,
+):
+    # Convert config scalars to Python types so they are not included in the
+    # autograd computation graph (they are not learnable parameters).
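+    # The kernel returns the routing coefficients (h_pre, h_post, h_res),
+    # which are consumed by liger_mhc_pre / liger_mhc_post_res below.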
+ return LigerMHCCoeffsFunction.apply( + x, + phi, + b, + alpha_pre, + alpha_post, + alpha_res, + allow_fp32, + int(tmax), + float(rms_eps), + float(pre_eps), + float(sinkhorn_eps), + float(post_mult), + ) + + +def liger_mhc_pre(x, h_pre): + return LigerMHCPreFunction.apply(x, h_pre) + + +def liger_mhc_post_res(x, f_out, h_post, h_res): + return LigerMHCPostResFunction.apply(x, f_out, h_post, h_res) + + +def liger_mhc_apply(x, f_out, h_pre, h_post, h_res, *, return_x_in: bool = False): + x_in = liger_mhc_pre(x, h_pre) + x_out = liger_mhc_post_res(x, f_out, h_post, h_res) + if return_x_in: + return x_out, x_in + return x_out + + +def liger_mhc_forward( + x, + layer, + phi, + b, + alpha_pre, + alpha_post, + alpha_res, + *, + allow_fp32=False, + tmax=20, + rms_eps=1e-6, + pre_eps=0.0, + sinkhorn_eps=1e-6, + post_mult=2.0, + return_coeffs=False, +): + """High-level helper: compute coeffs, apply pre, run layer, then apply post+res.""" + h_pre, h_post, h_res = liger_mhc_coeffs( + x, + phi, + b, + alpha_pre, + alpha_post, + alpha_res, + allow_fp32=allow_fp32, + tmax=tmax, + rms_eps=rms_eps, + pre_eps=pre_eps, + sinkhorn_eps=sinkhorn_eps, + post_mult=post_mult, + ) + x_in = liger_mhc_pre(x, h_pre) + layer_dtype = x_in.dtype + if hasattr(layer, "parameters"): + try: + layer_dtype = next(layer.parameters()).dtype + except StopIteration: + layer_dtype = x_in.dtype + if x_in.dtype != layer_dtype: + x_in = x_in.to(layer_dtype) + f_out = layer(x_in) + x_out = liger_mhc_post_res(x, f_out, h_post, h_res) + if return_coeffs: + return x_out, (h_pre, h_post, h_res) + return x_out diff --git a/src/liger_kernel/transformers/fused_add_rms_norm.py b/src/liger_kernel/transformers/fused_add_rms_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..e9bd0baa6aaba62326b8d6e00ad4b4daba99692f --- /dev/null +++ b/src/liger_kernel/transformers/fused_add_rms_norm.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn + +from liger_kernel.ops import LigerFusedAddRMSNormFunction + + +class LigerFusedAddRMSNorm(nn.Module): + def __init__( + self, + hidden_size, + eps=1e-6, + offset=0.0, + casting_mode="llama", + init_fn="ones", + in_place=False, + ): + super().__init__() + assert init_fn in [ + "ones", + "zeros", + ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}" + self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)) + self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (eps, offset, casting_mode, in_place) + + def forward(self, hidden_states, residual): + return LigerFusedAddRMSNormFunction.apply( + hidden_states, + residual, + self.weight, + self.variance_epsilon, + self.offset, + self.casting_mode, + self.in_place, + ) + + def extra_repr(self): + return ( + f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}" + ) diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py new file mode 100755 index 0000000000000000000000000000000000000000..c4a4474ce75c713c7199a786ed2408886ebfb0cb --- /dev/null +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -0,0 +1,69 @@ +from typing import Optional + +import torch + +from liger_kernel.ops import LigerFusedLinearCrossEntropyFunction +from liger_kernel.transformers.functional import CrossEntropyOutput + + +class LigerFusedLinearCrossEntropyLoss(torch.nn.Module): + def __init__( + self, + ce_weight: Optional[torch.FloatTensor] = None, + 
ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+        return_z_loss: bool = False,
+        accum_dtype: Optional[torch.dtype] = None,
+        use_token_scaling: bool = False,
+        return_token_accuracy: bool = False,
+        return_predicted_tokens: bool = False,
+    ):
+        super().__init__()
+        assert (label_smoothing >= 0) and (label_smoothing <= 1), (
+            f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        )
+        assert reduction in {
+            "mean",
+            "sum",
+            "none",
+        }, f"reduction must be 'mean' or 'sum' or 'none'. Got: {reduction}"
+        assert softcap is None or softcap > 0, f"softcap must be greater than 0.0 or None. Got: {softcap}"
+        self.ce_weight = ce_weight
+        self.ignore_index = ignore_index
+        self.lse_square_scale = lse_square_scale
+        self.label_smoothing = label_smoothing
+        self.reduction = reduction
+        self.softcap = softcap
+        self.return_z_loss = return_z_loss
+        self.accum_dtype = accum_dtype
+        self.use_token_scaling = use_token_scaling
+        self.return_token_accuracy = return_token_accuracy
+        self.return_predicted_tokens = return_predicted_tokens
+
+    def forward(self, lin_weight, _input, target, bias=None):
+        loss, z_loss, token_accuracy, predicted_tokens = LigerFusedLinearCrossEntropyFunction.apply(
+            _input,
+            lin_weight,
+            target,
+            bias,
+            self.ce_weight,
+            self.ignore_index,
+            self.lse_square_scale,
+            self.label_smoothing,
+            self.reduction,
+            self.softcap,
+            self.return_z_loss,
+            self.accum_dtype,
+            self.use_token_scaling,
+            self.return_token_accuracy,
+            self.return_predicted_tokens,
+        )
+        if not self.return_z_loss and not self.return_token_accuracy and not self.return_predicted_tokens:
+            return loss
+
+        return CrossEntropyOutput(
+            loss=loss, z_loss=z_loss, token_accuracy=token_accuracy, predicted_tokens=predicted_tokens
+        )
diff --git a/src/liger_kernel/transformers/fused_linear_jsd.py b/src/liger_kernel/transformers/fused_linear_jsd.py
new file mode 100755
index 0000000000000000000000000000000000000000..38f668c6f88397c3d58a004dc769daff016644b5
--- /dev/null
+++ b/src/liger_kernel/transformers/fused_linear_jsd.py
@@ -0,0 +1,95 @@
+from typing import Optional
+
+import torch
+
+from liger_kernel.ops import LigerFusedLinearJSDFunction
+
+
+class LigerFusedLinearJSD(torch.nn.Module):
+    r"""Fusing the last linear layer with generalized JSD
+
+    Handle the forward and backward pass of the final linear layer via JSD by avoiding
+    the materialization of the large logits tensor.
+
+    Args:
+        jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
+        ignore_index (int): The index to ignore in the target. Default: `-100`
+        temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
+
+    Shape:
+        - student_input: :math:`(BT, H)`, where B is batch size, T is sequence length, H is hidden dimension.
+        - student_weight: :math:`(V, H)`, where V is vocab size.
+        - teacher_input: :math:`(BT, H')`, where H' is hidden dimension of the teacher model.
+        - teacher_weight: :math:`(V, H')`, where hidden size H and H' can be different.
+        - shift_labels: :math:`(BT,)`
+        - Output: a scalar.
+
+    Examples:
+    ```python
+    >>> (B, T, H_s, H_t, V) = (2, 2, 3, 5, 10)
+    >>> fused_jsd = LigerFusedLinearJSD(jsd_beta=0.1, temperature=2.0)
+    >>> # generate inputs and weights
+    >>> student_input = torch.rand(B * T, H_s, device="cuda", requires_grad=True)
+    >>> student_lin = torch.nn.Linear(H_s, V, bias=False, device="cuda")
+    >>> # teacher input doesn't require grad, hidden_dim can be different from student's
+    >>> teacher_input = torch.rand(B * T, H_t, device="cuda")
+    >>> teacher_lin = torch.nn.Linear(H_t, V, bias=False, device="cuda")
+    >>> output = fused_jsd(student_input, student_lin.weight, teacher_input, teacher_lin.weight)
+    >>> output.backward()
+    >>>
+    >>> # Example with labels for supervised fine-tuning (SFT) context:
+    >>>
+    >>> # Assume hidden states, lm_heads and corresponding labels are given
+    >>> student_lm_head = torch.nn.Linear(H_s, V, bias=False, device="cuda")
+    >>> student_hidden_states = torch.randn(B, T, H_s, device="cuda", requires_grad=True)
+    >>> teacher_lm_head = torch.nn.Linear(H_t, V, bias=False, device="cuda")
+    >>> teacher_hidden_states = torch.randn(B, T, H_t, device="cuda")
+    >>> labels = torch.randint(0, V, (B, T), dtype=torch.long, device="cuda")
+    >>>
+    >>> # Shift so that tokens < n predict n
+    >>> shift_student_hidden_states = student_hidden_states[..., :-1, :].contiguous()
+    >>> shift_teacher_hidden_states = teacher_hidden_states[..., :-1, :].contiguous()
+    >>> shift_labels = labels[..., 1:].contiguous()
+    >>>
+    >>> # Flatten tokens
+    >>> shift_student_hidden_states = shift_student_hidden_states.view(-1, H_s)
+    >>> shift_teacher_hidden_states = shift_teacher_hidden_states.view(-1, H_t)
+    >>> shift_labels = shift_labels.view(-1)
+    >>>
+    >>> # Calculate loss
+    >>> loss_fct = LigerFusedLinearJSD(jsd_beta=0.1)
+    >>> loss = loss_fct(
+    >>>     shift_student_hidden_states,
+    >>>     student_lm_head.weight,
+    >>>     shift_teacher_hidden_states,
+    >>>     teacher_lm_head.weight,
+    >>>     shift_labels
+    >>> )
+    ```
+    """
+
+    def __init__(self, jsd_beta=0.5, ignore_index=-100, temperature=1.0):
+        super().__init__()
+        assert temperature != 0, "temperature cannot be 0."
+        self.jsd_beta = jsd_beta
+        self.temperature = temperature
+        self.ignore_index = ignore_index
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        shift_labels: Optional[torch.LongTensor],
+    ):
+        return LigerFusedLinearJSDFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            shift_labels,
+            self.jsd_beta,
+            self.ignore_index,
+            self.temperature,
+        )
diff --git a/src/liger_kernel/transformers/fused_neighborhood_attention.py b/src/liger_kernel/transformers/fused_neighborhood_attention.py
new file mode 100755
index 0000000000000000000000000000000000000000..92a3e8503379156c10cc7d0361ee05c833053ac8
--- /dev/null
+++ b/src/liger_kernel/transformers/fused_neighborhood_attention.py
@@ -0,0 +1,234 @@
+import math
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops import LigerFusedNeighborhoodAttentionFunction
+
+
+class LigerFusedNeighborhoodAttention(nn.Module):
+    """
+    Liger Fused Neighborhood Attention Module.
+
+    Paper: https://arxiv.org/pdf/2504.16922
+
+    Fused neighborhood attention restricts the attention mechanism to a local neighborhood
+    around each position, reducing computational complexity from O(n²) to O(n*k)
+    where k is the neighborhood size.
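+    Each query position attends only to keys within a window of `kernel_size`
+    tokens centered on it, optionally spaced out by `dilation`.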
+ + Args: + hidden_size (int): The hidden dimension size + num_heads (int): Number of attention heads + kernel_size (int): Size of the neighborhood window (default: 7) + dilation (int): Dilation factor for the neighborhood (default: 1) + bias (bool): Whether to use bias in linear projections (default: True) + dropout (float): Dropout probability (default: 0.0) + scale (Optional[float]): Scaling factor for attention scores. + If None, uses 1/sqrt(head_dim) (default: None) + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + kernel_size: int = 7, + dilation: int = 1, + bias: bool = True, + dropout: float = 0.0, + scale: Optional[float] = None, + ): + super().__init__() + + if hidden_size % num_heads != 0: + raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})") + + if kernel_size <= 0: + raise ValueError(f"kernel_size ({kernel_size}) must be positive") + + if kernel_size % 2 == 0: + raise ValueError(f"kernel_size ({kernel_size}) must be odd") + + if dilation < 1: + raise ValueError(f"dilation ({dilation}) must be positive") + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.kernel_size = kernel_size + self.dilation = dilation + self.scale = scale if scale is not None else 1.0 / math.sqrt(self.head_dim) + self.dropout_p = dropout + + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + + self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + + if dropout > 0.0: + self.dropout = nn.Dropout(dropout) + else: + self.dropout = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass of the fused neighborhood attention module. + + Args: + hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] + attention_mask (Optional[torch.Tensor]): Attention mask (currently not supported) + + Returns: + torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size] + """ + if attention_mask is not None: + raise NotImplementedError("Attention mask is not yet supported in LigerFusedNeighborhoodAttention") + + batch_size, seq_len, hidden_size = hidden_states.shape + + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + attn_output = LigerFusedNeighborhoodAttentionFunction.apply( + query, key, value, self.kernel_size, self.dilation, self.scale + ) + + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size) + + if self.dropout is not None: + attn_output = self.dropout(attn_output) + + output = self.out_proj(attn_output) + + return output + + def extra_repr(self) -> str: + return ( + f"hidden_size={self.hidden_size}, num_heads={self.num_heads}, " + f"head_dim={self.head_dim}, kernel_size={self.kernel_size}, " + f"dilation={self.dilation}, scale={self.scale}, dropout={self.dropout_p}" + ) + + +class LigerFusedNeighborhoodAttentionLayer(nn.Module): + """ + A complete neighborhood attention layer with layer norm and residual connection. 
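+
+    Uses a pre-norm layout: the input is layer-normalized, passed through
+    LigerFusedNeighborhoodAttention, optionally dropped out, and added back to
+    the residual stream.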
+ + Args: + hidden_size (int): The hidden dimension size + num_heads (int): Number of attention heads + kernel_size (int): Size of the neighborhood window (default: 7) + dilation (int): Dilation factor for the neighborhood (default: 1) + bias (bool): Whether to use bias in linear projections (default: True) + dropout (float): Dropout probability (default: 0.0) + layer_norm_eps (float): Epsilon for layer normalization (default: 1e-5) + scale (Optional[float]): Scaling factor for attention scores (default: None) + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + kernel_size: int = 7, + dilation: int = 1, + bias: bool = True, + dropout: float = 0.0, + layer_norm_eps: float = 1e-5, + scale: Optional[float] = None, + ): + super().__init__() + + self.attention = LigerFusedNeighborhoodAttention( + hidden_size=hidden_size, + num_heads=num_heads, + kernel_size=kernel_size, + dilation=dilation, + bias=bias, + dropout=dropout, + scale=scale, + ) + + self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + + if dropout > 0.0: + self.dropout = nn.Dropout(dropout) + else: + self.dropout = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass with residual connection and layer normalization. + + Args: + hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] + attention_mask (Optional[torch.Tensor]): Attention mask (currently not supported) + + Returns: + torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size] + """ + normed_hidden_states = self.layer_norm(hidden_states) + + attn_output = self.attention(normed_hidden_states, attention_mask) + + if self.dropout is not None: + attn_output = self.dropout(attn_output) + + output = hidden_states + attn_output + + return output + + +class LigerFusedNeighborhoodAttentionConfig: + """ + Configuration class for Fused Neighborhood Attention. + + This can be used to easily configure neighborhood attention parameters + for different model architectures. 
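+
+    Example (a minimal sketch; `to_dict()` mirrors the constructor arguments of
+    `LigerFusedNeighborhoodAttentionLayer`)::
+
+        config = LigerFusedNeighborhoodAttentionConfig(hidden_size=512, num_heads=8)
+        layer = LigerFusedNeighborhoodAttentionLayer(**config.to_dict())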
+ """ + + def __init__( + self, + hidden_size: int = 768, + num_heads: int = 12, + kernel_size: int = 7, + dilation: int = 1, + bias: bool = True, + dropout: float = 0.0, + layer_norm_eps: float = 1e-5, + scale: Optional[float] = None, + ): + self.hidden_size = hidden_size + self.num_heads = num_heads + self.kernel_size = kernel_size + self.dilation = dilation + self.bias = bias + self.dropout = dropout + self.layer_norm_eps = layer_norm_eps + self.scale = scale + + def to_dict(self): + return { + "hidden_size": self.hidden_size, + "num_heads": self.num_heads, + "kernel_size": self.kernel_size, + "dilation": self.dilation, + "bias": self.bias, + "dropout": self.dropout, + "layer_norm_eps": self.layer_norm_eps, + "scale": self.scale, + } diff --git a/src/liger_kernel/transformers/geglu.py b/src/liger_kernel/transformers/geglu.py new file mode 100755 index 0000000000000000000000000000000000000000..fb72cbbab508c1a2f20afda39ce2b8c2f60ed784 --- /dev/null +++ b/src/liger_kernel/transformers/geglu.py @@ -0,0 +1,22 @@ +import torch.nn as nn + +from liger_kernel.ops import LigerGELUMulFunction + + +class LigerGEGLUMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + # TODO: support exact GELU + # Right now Gemma 1, 1.1 and 2 models are all using `gelu_pytorch_tanh` + # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175 + # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/activations.py#L46 + # So we can safely assume we use tanh approximation form all the time + + def forward(self, x): + return self.down_proj(LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x))) diff --git a/src/liger_kernel/transformers/group_norm.py b/src/liger_kernel/transformers/group_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..efc6f8ac157e0beacbcd8d0d8b31f3ba16e64073 --- /dev/null +++ b/src/liger_kernel/transformers/group_norm.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn + +from liger_kernel.ops import LigerGroupNormFunction + + +class LigerGroupNorm(nn.Module): + def __init__(self, num_channels, num_groups, eps=1e-6, bias=False, init_fn="ones"): + """ + A Group Normalization layer. + Args: + num_channels (int): Number of channels in the input tensor. + num_groups (int): Number of groups to divide the channels into. + eps (float, optional): A value added to the denominator for numerical stability. Default: 1e-6. + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``False``. + init_fn (str, optional): Initialization function for the learnable parameters. Default: "ones". 
+ """ + super().__init__() + assert init_fn in [ + "ones", + "zeros", + ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}" + + assert num_channels % num_groups == 0, ( + f"Number of channels {num_channels} must be divisible by num_groups {num_groups}" + ) + self.num_channels = num_channels + self.num_groups = num_groups + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels)) + self.bias = nn.Parameter(torch.randn(num_channels) if bias else torch.zeros(num_channels)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # hidden_states: (batch_size, num_channels, *) + assert hidden_states.dim() >= 3, f"Input must have atleast 3 dimensions, got {hidden_states.dim()}" + assert hidden_states.size(1) == self.num_channels, ( + f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}" + ) + return LigerGroupNormFunction.apply( + hidden_states, + self.weight, + self.bias, + self.num_channels, + self.num_groups, + self.variance_epsilon, + ) + + def extra_repr(self): + return f"{self.hidden_size}, num_channels={self.num_channels}, num_groups={self.num_groups}, eps={self.eps}" diff --git a/src/liger_kernel/transformers/grpo_loss.py b/src/liger_kernel/transformers/grpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..caa053bd6e74acfa5a035fc06fd5dfdff2c0aeb5 --- /dev/null +++ b/src/liger_kernel/transformers/grpo_loss.py @@ -0,0 +1,206 @@ +import torch + +from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase +from liger_kernel.ops import GrpoLossFunction + + +def triton_grpo_loss( + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask=None, + temperature=0.9, + beta=0.04, + eps_low=0.2, + eps_high=0.4, + inplace=True, + loss_type="dapo", + max_completion_length=None, + importance_sampling_level="token", + reduce=False, + sapo_temperature_pos=1.0, + sapo_temperature_neg=1.05, + vllm_is_ratio=None, + delta=None, + use_bias_correction_kl=False, +): + """ + Triton-optimized GRPO loss function. + + Args: + logits: Model logits (B, L+1, V) + old_logp: Old policy log probabilities (B, L) or None + ref_logp: Reference model log probabilities (B, L) or None (required if beta != 0) + completion_ids: Token IDs for completions (B, L) + advantages: Per-sequence advantages (B,) + completion_mask: Mask for valid tokens (B, L) or None + temperature: Temperature for log softmax + beta: KL penalty coefficient + eps_low: Lower clipping bound for importance ratio + eps_high: Upper clipping bound for importance ratio + inplace: Whether to modify logits in-place during backward + loss_type: Loss reduction type ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo") + max_completion_length: Max completion length for dr_grpo loss type; defaults to sequence length if None + importance_sampling_level: "token" or "sequence" importance sampling + reduce: If True, return reduced loss; if False, return per-token loss + vllm_is_ratio: vLLM importance sampling ratio (B, L) or (B, 1) or None. + Used to correct for distribution mismatch when using vLLM for generation. + Applied to PPO loss BEFORE adding KL penalty. + delta: Upper clamp for two-sided clipping (INTELLECT-2). When set, coef_1 is clamped + to max=delta before computing the PPO loss. Only supported for standard PPO loss + types (grpo, bnpo, dr_grpo, dapo, luspo). None means disabled. 
+        use_bias_correction_kl: If True, multiply KL divergence by coef_1 (importance sampling
+            ratio) for bias-corrected KL estimation (DeepSeek-V3.2). Default False.
+
+    Returns:
+        If reduce=True: (loss, metrics) where metrics = [kl_mean, clip_ratio] or [clip_ratio]
+        If reduce=False: (per_token_loss, per_token_kl, is_clipped)
+    """
+    assert logits is not None and completion_ids is not None and advantages is not None, (
+        "must provide logits, completion_ids and advantages"
+    )
+    assert importance_sampling_level in ("token", "sequence"), (
+        f"importance_sampling_level must be 'token' or 'sequence', got {importance_sampling_level}"
+    )
+
+    result = GrpoLossFunction.apply(
+        logits,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace,
+        loss_type,
+        max_completion_length,
+        reduce,
+        importance_sampling_level,
+        sapo_temperature_pos,
+        sapo_temperature_neg,
+        vllm_is_ratio,
+        delta,
+        use_bias_correction_kl,
+    )
+
+    if not reduce:
+        # Returns (per_token_loss, per_token_kl, is_clipped) - all (B, L) tensors
+        return result
+
+    # reduce=True: Returns (reduced_loss, kl_mean, clip_ratio) - all scalars
+    reduced_loss, kl_mean, clip_ratio = result
+    metrics = []
+    if beta != 0.0 and kl_mean is not None:
+        metrics.append(kl_mean)
+    metrics.append(clip_ratio)
+    return reduced_loss, metrics
+
+
+def _reduce_grpo_loss(per_token_loss, completion_mask, loss_type, max_completion_length):
+    mask = completion_mask
+    if mask is None:
+        mask = torch.ones_like(per_token_loss, dtype=per_token_loss.dtype, device=per_token_loss.device)
+    mask = mask.to(per_token_loss.dtype)
+
+    if loss_type == "grpo" or loss_type == "sapo":
+        # SAPO uses the same normalization as GRPO (per-sequence average)
+        per_seq = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
+        return per_seq.mean()
+    if loss_type == "bnpo":
+        return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
+    if loss_type == "dr_grpo":
+        batch = per_token_loss.shape[0]
+        max_len = max_completion_length if max_completion_length is not None else per_token_loss.shape[1]
+        return (per_token_loss * mask).sum() / (batch * max_len)
+    if loss_type == "dapo" or loss_type == "cispo":
+        # CISPO uses the same normalization as DAPO
+        normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(mask)
+        return (per_token_loss * mask).sum() / normalizer
+    if loss_type == "luspo":
+        # LUSPO: scale each sequence's loss by its valid token count, then average across sequences
+        return (per_token_loss * mask.sum(-1, keepdim=True)).mean()
+    raise ValueError(f"Unsupported loss_type '{loss_type}' for Triton GRPO loss.")
+
+
+def _masked_mean(values, mask):
+    if mask is None:
+        mask = torch.ones_like(values, dtype=values.dtype, device=values.device)
+    mask = mask.to(values.dtype)
+    return (values * mask).sum() / mask.sum().clamp(min=1.0)
+
+
+# This is a demo of how to use grpo_loss in GRPOTrainer. The TRL version must be >= 0.26.2.
+"""
+import torch
+import trl
+from packaging.version import Version
+assert Version(trl.__version__) >= Version("0.26.2"), "please pip install trl>=0.26.2"
+from trl.extras.profiling import profiling_decorator
+
+@profiling_decorator
+def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+    # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
+    logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+    return fused_selective_log_softmax(logits, input_ids, self.temperature, mask=attention_mask)

+@profiling_decorator
+def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+    if return_outputs:
+        raise ValueError("The GRPOTrainer does not support returning outputs")
+    # Compute the per-token log probabilities for the model
+
+    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+    completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+    logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+    logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+
+    ref_per_token_logps = inputs["ref_per_token_logps"]
+    advantages = inputs["advantages"]
+    old_per_token_logps = inputs["old_per_token_logps"]
+
+    # Get vLLM importance sampling ratio if using vLLM with importance sampling correction
+    vllm_is_ratio = inputs.get("importance_sampling_ratio", None)
+
+    per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(
+        logits,
+        old_per_token_logps,
+        ref_per_token_logps,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature=self.temperature,
+        beta=self.beta,
+        eps_low=self.epsilon_low,
+        eps_high=self.epsilon_high,
+        importance_sampling_level=self.importance_sampling_level,  # "token" or "sequence"
+        vllm_is_ratio=vllm_is_ratio,  # vLLM distribution correction
+    )
+    loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
+
+    # Log the metrics
+    mode = "eval" if self.control.should_evaluate else "train"
+
+    if self.beta != 0.0:
+        mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
+        self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+    clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
+    self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
+    return loss
+
+trl.GRPOTrainer._get_per_token_logps = _get_per_token_logps
+trl.GRPOTrainer.compute_loss = compute_loss
+trigger = None
+"""
+
+# Add this line as the first line of grpo.py in open-r1
+"""
+from liger_kernel.transformers.grpo_loss import trigger
+"""
diff --git a/src/liger_kernel/transformers/jsd.py b/src/liger_kernel/transformers/jsd.py
new file mode 100755
index 0000000000000000000000000000000000000000..a8489e7d1a42a2001c6846eeaf228571267474c5
--- /dev/null
+++ b/src/liger_kernel/transformers/jsd.py
@@ -0,0 +1,70 @@
+from typing import Optional
+
+import torch
+
+from liger_kernel.ops import LigerJSDFunction
+
+
+class LigerJSD(torch.nn.Module):
+    r"""The generalized Jensen-Shannon Divergence.
+    .. math::
+        JSD(\beta)(P || Q)
+            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
+
+    .. note::
+        As with all the other losses in PyTorch, this function expects the first argument,
+        :attr:`log_q`, to be the predictions, the output of the student model in log-space,
+        and the second, :attr:`log_p`, to be the observations, the output of the teacher model in log-space.
+        This differs from the standard mathematical notation :math:`JSD(P || Q)` where
+        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
+
+    Args:
+        beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
+        ignore_index (int): The index to ignore in the target. Default: `-100`
+
+    Shape:
+        - Input: :math:`(BT, V)`, where B is batch size, T is sequence length, V is vocab size.
+        - Target: :math:`(BT, V)`, same shape as the input.
+        - shift_labels (Optional): :math:`(BT,)`
+        - Output: a scalar.
+
+    Examples:
+    ```python
+    >>> (B, T, V) = (2, 2, 5)
+    >>> jsd = LigerJSD(beta=0.1)
+    >>> # input should be a distribution in the log space
+    >>> input = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1)
+    >>> target = torch.randn(B * T, V).log_softmax(dim=-1)
+    >>> output = jsd(input, target)
+    >>>
+    >>> # Example with labels for supervised fine-tuning (SFT) context
+    >>> # Assume logits and corresponding labels are given
+    >>> student_logits = torch.randn(B, T, V, requires_grad=True).log_softmax(dim=-1)
+    >>> teacher_logits = torch.randn(B, T, V).log_softmax(dim=-1)
+    >>> labels = torch.randint(0, V, (B, T), dtype=torch.long)
+    >>> # Shift so that tokens < n predict n
+    >>> shift_student_logits = student_logits[..., :-1, :].contiguous()
+    >>> shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()
+    >>> shift_labels = labels[..., 1:].contiguous()
+    >>> # Flatten tokens
+    >>> shift_student_logits = shift_student_logits.view(-1, V)
+    >>> shift_teacher_logits = shift_teacher_logits.view(-1, V)
+    >>> shift_labels = shift_labels.view(-1)
+    >>> # Calculate loss
+    >>> loss_fct = LigerJSD(beta=0.1)
+    >>> loss = loss_fct(shift_student_logits, shift_teacher_logits, shift_labels)
+
+    ```
+    """
+
+    def __init__(self, beta: float = 0.5, ignore_index: int = -100):
+        super().__init__()
+        self.beta = beta
+        self.ignore_index = ignore_index
+
+    def forward(
+        self,
+        log_q: torch.Tensor,
+        log_p: torch.Tensor,
+        shift_labels: Optional[torch.LongTensor] = None,
+    ):
+        return LigerJSDFunction.apply(log_q, log_p, shift_labels, self.beta, self.ignore_index)
diff --git a/src/liger_kernel/transformers/kl_div.py b/src/liger_kernel/transformers/kl_div.py
new file mode 100755
index 0000000000000000000000000000000000000000..97d9e68c591e7a7a8e48f4b529fce320511d9948
--- /dev/null
+++ b/src/liger_kernel/transformers/kl_div.py
@@ -0,0 +1,12 @@
+import torch.nn as nn
+
+from liger_kernel.ops import LigerKLDivLossFunction
+
+
+class LigerKLDIVLoss(nn.KLDivLoss):
+    def __init__(self, eps: float = 1e-10, *args, **kwargs):
+        super(LigerKLDIVLoss, self).__init__(*args, **kwargs)
+        self.eps = eps
+
+    def forward(self, y_pred, y_true):
+        return LigerKLDivLossFunction.apply(y_pred, y_true, self.reduction, self.log_target, self.eps)
diff --git a/src/liger_kernel/transformers/layer_norm.py b/src/liger_kernel/transformers/layer_norm.py
new file mode 100755
index 0000000000000000000000000000000000000000..34a6325117a65cc99967ba86805a68eb84a2ffa8
--- /dev/null
+++
b/src/liger_kernel/transformers/layer_norm.py @@ -0,0 +1,24 @@ +import torch +import torch.nn as nn + +from liger_kernel.ops import LigerLayerNormFunction + + +class LigerLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6, bias=False, init_fn="ones"): + super().__init__() + assert init_fn in [ + "ones", + "zeros", + ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}" + self.hidden_size = hidden_size + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)) + self.bias = nn.Parameter(torch.randn(hidden_size) if bias else torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + return LigerLayerNormFunction.apply(hidden_states, self.weight, self.bias, self.variance_epsilon) + + def extra_repr(self): + return f"{self.hidden_size}, eps={self.eps}" diff --git a/src/liger_kernel/transformers/llama4_rope.py b/src/liger_kernel/transformers/llama4_rope.py new file mode 100755 index 0000000000000000000000000000000000000000..d808fdec6ccdcd51ff51268f999b4dec43d15305 --- /dev/null +++ b/src/liger_kernel/transformers/llama4_rope.py @@ -0,0 +1,93 @@ +""" +Liger Kernel implementation of Llama4 Rotary Position Embedding (RoPE). +Supports both text and vision RoPE variants with fused operations for optimal performance. +""" + +import torch + +from liger_kernel.ops import LigerLlama4RopeFunction + + +def liger_llama4_text_rotary_pos_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Liger-optimized implementation of Llama4 text rotary position embedding. + + This implementation uses a fused Triton kernel for complex multiplication, + providing significant performance improvements over the original PyTorch implementation. + + Args: + xq (torch.Tensor): Query tensor of shape (batch_size, seq_len, num_heads, head_dim) + xk (torch.Tensor): Key tensor of shape (batch_size, seq_len, num_heads, head_dim) + freqs_cis (torch.Tensor): Complex frequency tensor from Llama4TextRotaryEmbedding + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors + """ + # Use fused Triton kernel for complex RoPE + return LigerLlama4RopeFunction.apply(xq, xk, freqs_cis) + + +def liger_llama4_vision_rotary_pos_emb( + query: torch.Tensor, + key: torch.Tensor, + freqs_ci: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Liger-optimized implementation of Llama4 vision rotary position embedding. + + This implementation uses the same fused Triton kernel as text RoPE, + providing performance improvements for vision transformer attention. 
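+
+    `freqs_ci` may arrive as a 2D, 3D, or 4D tensor depending on how the vision
+    rotary embedding was computed; the shape-handling logic below normalizes it
+    to a batched 3D layout before invoking the fused kernel.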
+ + Args: + query (torch.Tensor): Query tensor of shape (batch_size, seq_len, num_heads, head_dim) + key (torch.Tensor): Key tensor of shape (batch_size, seq_len, num_heads, head_dim) + freqs_ci (torch.Tensor): Complex frequency tensor for 2D positions + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors + """ + # Handle broadcasting for vision RoPE + if freqs_ci.dim() == 3: + try: + # Try the regular 3D expansion + freqs_ci = freqs_ci.unsqueeze(0).expand(query.shape[0], -1, -1) + except RuntimeError as e: + if "expand" in str(e) and "4" in str(e): + # The tensor is actually 4D internally, handle it differently + freqs_ci = freqs_ci.squeeze(1) # Remove the middle dimension + freqs_ci = freqs_ci.unsqueeze(0).expand(query.shape[0], -1, -1) + else: + raise e + elif freqs_ci.dim() == 4: # (1, seq_len, 1, head_dim//2) - already properly shaped + # Squeeze the middle dimension to get (1, seq_len, head_dim//2) + freqs_ci = freqs_ci.squeeze(2) + elif freqs_ci.dim() == 2: # (seq_len, head_dim//2) - needs expansion + freqs_ci = freqs_ci.unsqueeze(0).expand(query.shape[0], -1, -1) + else: + raise ValueError(f"Unexpected freqs_ci shape: {freqs_ci.shape}") + + # Use the same fused kernel as text RoPE + return LigerLlama4RopeFunction.apply(query, key, freqs_ci) + + +# Note: We only patch the functions, not the classes +# The original Llama4TextRotaryEmbedding and Llama4VisionRotaryEmbedding classes remain unchanged + + +# Convenience functions for monkey patching +def apply_liger_llama4_rope_full(modeling_module): + """ + Apply Liger optimizations to Llama4 RoPE functions. + + Args: + modeling_module: The transformers modeling module to patch + """ + # Replace the text RoPE function + modeling_module.apply_rotary_emb = liger_llama4_text_rotary_pos_emb + + # Replace the vision RoPE function + modeling_module.vision_apply_rotary_emb = liger_llama4_vision_rotary_pos_emb diff --git a/src/liger_kernel/transformers/mhc.py b/src/liger_kernel/transformers/mhc.py new file mode 100755 index 0000000000000000000000000000000000000000..30459dfbe98eb0e924956c12a9fef90988c35d2b --- /dev/null +++ b/src/liger_kernel/transformers/mhc.py @@ -0,0 +1,162 @@ +import warnings + +import torch +import torch.nn as nn + +from liger_kernel.transformers.functional import liger_mhc_coeffs +from liger_kernel.transformers.functional import liger_mhc_post_res +from liger_kernel.transformers.functional import liger_mhc_pre + + +class LigerMHC(nn.Module): + """ + Manifold-Constrained Hyper-Connections (mHC) wrapper. + + Wraps an arbitrary layer ``F: [..., C] -> [..., C]`` with multiple residual + streams, following the mHC architecture (arXiv:2512.24880). The input is a + multi-stream tensor of shape ``[..., HC, C]`` where ``HC`` is the number of + residual streams. + + The forward pass performs: + + 1. **Coefficients** -- Compute data-dependent routing coefficients + (``h_pre``, ``h_post``, ``h_res``) via a fused matmul + RMS + normalization + Sinkhorn-Knopp iterations. + 2. **Pre-aggregate** -- ``x_in = sum_i h_pre[i] * x[i]`` + (shape: ``[..., C]``) + 3. **Layer** -- ``f_out = layer(x_in)`` (shape: ``[..., C]``) + 4. **Post + residual** -- + ``x_out[o] = sum_i h_res[o,i] * x[i] + h_post[o] * f_out`` + (shape: ``[..., HC, C]``) + + Args: + layer: The module applied to the aggregated single-stream input. + Must accept ``[..., C]`` and return ``[..., C]``. Common choices + include ``nn.Linear``, attention layers, or MLP blocks. + hc: Number of residual streams (called *n* in the original paper). 
+ Recommended range: [2, 16]. Larger values increase register + pressure and Triton compile time. + c: Per-stream channel dimension. + tmax: Maximum Sinkhorn-Knopp iterations for doubly stochastic + normalization of ``h_res``. Default: 20. + rms_eps: Epsilon for RMS normalization of the projection. + Default: 1e-6. + pre_eps: Additive epsilon for ``h_pre`` after sigmoid. Default: 0.0. + sinkhorn_eps: Epsilon added during Sinkhorn normalization. + Default: 1e-6. + post_mult: Scaling factor for ``h_post`` after sigmoid. Default: 2.0. + phi_dtype: Dtype for the projection matrix ``phi``. Using float16 or + bfloat16 enables Tensor Core acceleration. Default: torch.float16. + allow_fp32: If True, accept FP32 input tensors. Note that FP32 mode + does **not** use Tensor Cores and will be slower. Default: False. + + Learnable Parameters: + - **phi** ``[HC*C, HC*HC + 2*HC]`` -- Projection matrix for computing + routing coefficients from flattened stream tokens. + - **b** ``[HC*HC + 2*HC]`` -- Bias for routing logits (float32). + - **alpha_pre** (scalar) -- Scales pre-routing logits before sigmoid. + - **alpha_post** (scalar) -- Scales post-routing logits before sigmoid. + - **alpha_res** (scalar) -- Scales residual logits before Sinkhorn. + + Example:: + + import torch + import torch.nn as nn + from liger_kernel.transformers import LigerMHC + + # Wrap a linear layer with 4 residual streams of dimension 256 + layer = nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16) + mhc = LigerMHC(layer, hc=4, c=256, phi_dtype=torch.bfloat16).cuda() + + # Input: [batch, seq_len, num_streams, channels] + x = torch.randn(2, 128, 4, 256, device="cuda", dtype=torch.bfloat16) + out = mhc(x) # shape: [2, 128, 4, 256] + + # In a transformer block (pseudocode): + # x = mhc_attn(x) # attention wrapped in LigerMHC + # x = mhc_ffn(x) # FFN wrapped in LigerMHC + """ + + def __init__( + self, + layer: nn.Module, + *, + hc: int, + c: int, + tmax: int = 20, + rms_eps: float = 1e-6, + pre_eps: float = 0.0, + sinkhorn_eps: float = 1e-6, + post_mult: float = 2.0, + phi_dtype: torch.dtype = torch.float16, + allow_fp32: bool = False, + ): + super().__init__() + self.layer = layer + # hc: number of residual streams (n in the paper) + self.hc = int(hc) + self.c = int(c) + + if hc > 16: + warnings.warn( + f"hc={hc} exceeds recommended range [2, 16]. " + "Large values may cause register pressure and increased compile time.", + stacklevel=2, + ) + self.tmax = int(tmax) + self.rms_eps = float(rms_eps) + self.pre_eps = float(pre_eps) + self.sinkhorn_eps = float(sinkhorn_eps) + self.post_mult = float(post_mult) + self.allow_fp32 = bool(allow_fp32) + + m = hc * hc + 2 * hc + k = hc * c + + try: + layer_device = next(self.layer.parameters()).device + except StopIteration: + layer_device = torch.device("cpu") + + # Note: for best speed, keep phi in BF16/FP16 to enable tensor-core matmul in Triton. 
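+        # phi projects each flattened multi-stream token (k = hc*c values) onto
+        # m = hc*hc + 2*hc routing logits: hc*hc for h_res plus hc each for h_pre and h_post.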
+ self.phi = nn.Parameter(torch.randn(k, m, dtype=phi_dtype, device=layer_device) * 0.02) + self.b = nn.Parameter(torch.zeros(m, dtype=torch.float32, device=layer_device)) + self.alpha_pre = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=layer_device)) + self.alpha_post = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=layer_device)) + self.alpha_res = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=layer_device)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x: [..., HC, C] (BF16/FP16 recommended; FP32 allowed if allow_fp32=True) + returns: [..., HC, C] + """ + if x.shape[-2] != self.hc or x.shape[-1] != self.c: + raise ValueError(f"Expected x.shape[-2:]=[{self.hc}, {self.c}], got {list(x.shape[-2:])}") + + h_pre, h_post, h_res = liger_mhc_coeffs( + x, + self.phi, + self.b, + self.alpha_pre, + self.alpha_post, + self.alpha_res, + allow_fp32=self.allow_fp32, + tmax=self.tmax, + rms_eps=self.rms_eps, + pre_eps=self.pre_eps, + sinkhorn_eps=self.sinkhorn_eps, + post_mult=self.post_mult, + ) + x_in = liger_mhc_pre(x, h_pre) # [..., C] + layer_dtype = x_in.dtype + for param in self.layer.parameters(recurse=True): + layer_dtype = param.dtype + break + if x_in.dtype != layer_dtype: + x_in = x_in.to(layer_dtype) + f_out = self.layer(x_in) # [..., C] + x_out = liger_mhc_post_res(x, f_out, h_post, h_res) # [..., HC, C] + return x_out + + def extra_repr(self) -> str: + return f"hc={self.hc}, c={self.c}, tmax={self.tmax}" diff --git a/src/liger_kernel/transformers/model/__init__.py b/src/liger_kernel/transformers/model/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/liger_kernel/transformers/model/exaone4.py b/src/liger_kernel/transformers/model/exaone4.py new file mode 100755 index 0000000000000000000000000000000000000000..c1fb863b12c6b4b08608202d9d99af5870546f83 --- /dev/null +++ b/src/liger_kernel/transformers/model/exaone4.py @@ -0,0 +1,139 @@ +from typing import List +from typing import Optional +from typing import Union + +import torch + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, +) -> LigerCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+
+    logits_to_keep (`int` or `torch.Tensor`, *optional*):
+        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+        This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Exaone4ForCausalLM
+
+    >>> model = Exaone4ForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-1.2B")
+    >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-1.2B")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    # Remove output-control parameters that shouldn't be passed to loss functions
+    kwargs.pop("return_dict", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+    predicted_tokens = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
+
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+ output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return custom output class with accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/falcon_h1.py b/src/liger_kernel/transformers/model/falcon_h1.py new file mode 100755 index 0000000000000000000000000000000000000000..1529ca057983706f5801242bf253fdddc75d4d14 --- /dev/null +++ b/src/liger_kernel/transformers/model/falcon_h1.py @@ -0,0 +1,125 @@ +from typing import TYPE_CHECKING +from typing import Optional +from typing import Union + +import torch + +if TYPE_CHECKING: + from transformers.models.falcon_h1.modeling_falcon_h1 import FalconHybridMambaAttentionDynamicCache + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional["FalconHybridMambaAttentionDynamicCache"] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, +) -> Union[tuple, LigerCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, FalconH1ForCausalLM + + >>> model = FalconH1ForCausalLM.from_pretrained("...") + >>> tokenizer = AutoTokenizer.from_pretrained("...") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + # if in training mode, don't materialize logits + if skip_logits and labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and labels is not None + + # Compute loss + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/gemma.py b/src/liger_kernel/transformers/model/gemma.py new file mode 100755 index 0000000000000000000000000000000000000000..7fcdc9e282fe91a7d68350bfca7c27261eb748ac --- /dev/null +++ b/src/liger_kernel/transformers/model/gemma.py @@ -0,0 +1,144 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from transformers.cache_utils import Cache + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: 
Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, LigerCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" 
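+ >>> # Illustrative only: passing `labels` computes the loss; during training
+ >>> # the fused Liger kernel does so without materializing the full logits.
+ >>> outputs = model(**inputs, labels=inputs.input_ids)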
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + # Compute loss + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output_tuple = (logits,) + outputs[1:] + if loss is not None: + output_tuple = (loss,) + output_tuple + if token_accuracy is not None: + output_tuple = output_tuple + (token_accuracy,) + if predicted_tokens is not None: + output_tuple = output_tuple + (predicted_tokens,) + return output_tuple + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/gemma2.py b/src/liger_kernel/transformers/model/gemma2.py new file mode 100755 index 0000000000000000000000000000000000000000..5eebd22b2852712cc6b50083eaa8f8a5af8c6e5e --- /dev/null +++ b/src/liger_kernel/transformers/model/gemma2.py @@ -0,0 +1,157 @@ +import logging + +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from transformers.cache_utils import Cache + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast + +logger = logging.getLogger(__name__) + + +def lce_forward( + self, + input_ids: torch.LongTensor = None, + 
attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Gemma2ForCausalLM
+
+ >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+
+ >>> prompt = "What is your favorite condiment?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "What is your favorite condiment?"
+ ```"""
+
+ if self.training and self.config._attn_implementation != "eager":
+ logger.warning_once(
+ "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
+ f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`."
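+ # (likely because non-eager kernels do not apply Gemma2's attention logit soft-capping)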
+ ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + # Compute loss + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + final_logit_softcapping=self.config.final_logit_softcapping, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + + loss = None + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.vocab_size, + **kwargs, + ) + + if not return_dict: + output_tuple = (logits,) + outputs[1:] + output_tuple = (loss,) + output_tuple if loss is not None else output_tuple + output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple + output_tuple = output_tuple + (predicted_tokens,) if predicted_tokens is not None else output_tuple + return output_tuple + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/gemma3.py b/src/liger_kernel/transformers/model/gemma3.py new file mode 100755 index 0000000000000000000000000000000000000000..74aae960070d49cdc5dabd586874b6d23d103247 --- /dev/null +++ b/src/liger_kernel/transformers/model/gemma3.py @@ -0,0 +1,343 @@ +from typing import Optional +from typing import Tuple +from typing import Union + +import torch +import torch.nn as nn + +from transformers.cache_utils import Cache +from transformers.utils import logging + +from 
liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+from liger_kernel.transformers.model.output_classes import LigerGemma3CausalLMOutputWithPast
+
+logger = logging.get_logger(__name__)
+
+
+def causal_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ skip_logits: Optional[bool] = None,
+ **loss_kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Gemma3ForCausalLM
+
+ >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-3-1b-it")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
+
+ >>> prompt = "What is your favorite condiment?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "What is your favorite condiment?"
+ ```"""
+
+ if self.training and self.config._attn_implementation != "eager":
+ logger.warning_once(
+ "It is strongly recommended to train Gemma3 models with the `eager` attention implementation "
+ f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`."
+ ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **loss_kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + shift_labels = loss_kwargs.pop("shift_labels", None) + loss = None + logits = None + token_accuracy = None + predicted_tokens = None + + if skip_logits is None: + skip_logits = self.training and (labels is not None or shift_labels is not None) + + # Compute loss + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + final_logit_softcapping=self.config.final_logit_softcapping, + **loss_kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + else: + logits = self.lm_head(kept_hidden_states) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.vocab_size, + **loss_kwargs, + ) + + if not return_dict: + output_tuple = (logits,) + outputs[1:] + output_tuple = (loss,) + output_tuple if loss is not None else output_tuple + output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple + output_tuple = output_tuple + (predicted_tokens,) if predicted_tokens is not None else output_tuple + return output_tuple + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) + + +def multimodal_forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: 
Optional[bool] = None, + **lm_kwargs, +) -> Union[tuple, LigerGemma3CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration + + >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it") + >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it") + + >>> messages = [ + ... { + ... "role": "system", + ... "content": [ + ... {"type": "text", "text": "You are a helpful assistant."} + ... ] + ... }, + ... { + ... "role": "user", "content": [ + ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"}, + ... {"type": "text", "text": "Where is the cat standing?"}, + ... ] + ... }, + ... ] + + >>> inputs = processor.apply_chat_template( + ... messages, + ... tokenize=True, + ... return_dict=True, + ... return_tensors="pt", + ... add_generation_prompt=True + ... ) + >>> # Generate + >>> generate_ids = model.generate(**inputs) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to" + ``` + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **lm_kwargs, + ) + + shift_labels = lm_kwargs.pop("shift_labels", None) + hidden_states = outputs[0] + + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + loss = None + logits = None + token_accuracy = None + predicted_tokens = None + if skip_logits and labels is None: + raise ValueError("skip_logits is True, but labels is None") + + if skip_logits is None: + skip_logits = self.training and (labels is not None) + + if skip_logits: + shift_hidden_states = kept_hidden_states[..., :-1, :] + shift_labels = labels[..., 1:] + + hidden_device = shift_hidden_states.device + if attention_mask is not None: + # we use the input attention mask to shift the hidden_states and labels, because it is 2D. 
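+ # masking this way keeps each kept hidden state aligned one-to-one with its next-token label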
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_hidden_states.shape[1] :].to(hidden_device) + shift_hidden_states = shift_hidden_states[shift_attention_mask.to(hidden_device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_hidden_states = shift_hidden_states.contiguous() + shift_labels = shift_labels.contiguous() + + # Flatten hidden state + shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size) + shift_labels = shift_labels.view(-1).to(hidden_device) + + result = LigerForCausalLMLoss( + hidden_states=shift_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=shift_labels, + hidden_size=self.config.text_config.hidden_size, + shift_labels=shift_labels, + final_logit_softcapping=getattr(self.config.text_config, "final_logit_softcapping", None), + **lm_kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + elif shift_labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
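+ # shift_labels is expected to arrive pre-shifted from the caller, so only the logits are sliced here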
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + output = (loss,) + output if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + return LigerGemma3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/glm4.py b/src/liger_kernel/transformers/model/glm4.py new file mode 100755 index 0000000000000000000000000000000000000000..cb314292a1877110b834e0ce113600a216ff21e5 --- /dev/null +++ b/src/liger_kernel/transformers/model/glm4.py @@ -0,0 +1,141 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, LigerCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. 
+ This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Glm4ForCausalLM + + >>> model = Glm4ForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414") + >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/GLM-4-9B-0414") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m' + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + # Compute loss + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/glm4v.py 
b/src/liger_kernel/transformers/model/glm4v.py new file mode 100755 index 0000000000000000000000000000000000000000..0dd3cda7f9d67f130753b4119f5a1ecf2768d2af --- /dev/null +++ b/src/liger_kernel/transformers/model/glm4v.py @@ -0,0 +1,165 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+def lce_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ mm_token_type_ids: Optional[torch.IntTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoProcessor, Glm4vForConditionalGeneration
+
+ >>> MODEL_PATH = "THUDM/GLM-4.1V-9B-Thinking"
+ >>> messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "url": "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png"
+ },
+ {
+ "type": "text",
+ "text": "describe this image"
+ }
+ ],
+ }
+ ]
+ >>> processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
+ >>> model = Glm4vForConditionalGeneration.from_pretrained(
+ pretrained_model_name_or_path=MODEL_PATH,
+ dtype=torch.bfloat16,
+ device_map="auto",
+ )
+ >>> inputs = processor.apply_chat_template(
+ messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_dict=True,
+ return_tensors="pt"
+ ).to(model.device)
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=8192)
+ >>> processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
+ Got it, let's describe the image.
First, there's a vintage car, specifically a Volkswagen Beetle + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + mm_token_type_ids=mm_token_type_ids, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + # Compute loss + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/glm4v_moe.py b/src/liger_kernel/transformers/model/glm4v_moe.py new file mode 100755 index 0000000000000000000000000000000000000000..3203958f8111242933f70128e14c7e4f2565842b --- /dev/null +++ b/src/liger_kernel/transformers/model/glm4v_moe.py @@ -0,0 +1,174 @@ +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerGlm4vMoeCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: torch.LongTensor = None, + 
attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + mm_token_type_ids: Optional[torch.IntTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, LigerGlm4vMoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). 
+ + Example: + + ```python + >>> from transformers import AutoProcessor, Glm4vMoeForConditionalGeneration + >>> import torch + + >>> MODEL_PATH = "zai-org/GLM-4.5V" + >>> messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png" + }, + { + "type": "text", + "text": "describe this image" + } + ], + } + ] + >>> processor = AutoProcessor.from_pretrained(MODEL_PATH) + >>> model = Glm4vMoeForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path=MODEL_PATH, + dtype="auto", + device_map="auto", + ) + >>> inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt" + ).to(model.device) + >>> inputs.pop("token_type_ids", None) + >>> generated_ids = model.generate(**inputs, max_new_tokens=8192) + >>> output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + mm_token_type_ids=mm_token_type_ids, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + # Compute loss + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Build output kwargs and include aux_loss only if present (depends on transformers version) + output_kwargs = dict( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + 
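# token-level metrics below are populated only on the fused-loss (skip_logits) path +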
token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) + if hasattr(outputs, "aux_loss"): + output_kwargs["aux_loss"] = outputs.aux_loss + + # Return GLM4V MoE output with accuracy + return LigerGlm4vMoeCausalLMOutputWithPast(**output_kwargs) diff --git a/src/liger_kernel/transformers/model/gpt_oss.py b/src/liger_kernel/transformers/model/gpt_oss.py new file mode 100755 index 0000000000000000000000000000000000000000..8787fde65d0cd1a814cb18a25c8acbf72b118537 --- /dev/null +++ b/src/liger_kernel/transformers/model/gpt_oss.py @@ -0,0 +1,213 @@ +from typing import List +from typing import Optional +from typing import Union + +import torch + +from transformers.modeling_outputs import MoeModelOutputWithPast +from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> LigerMoeCausalLMOutputWithPast: + r""" + Forward pass for causal language modeling with Mixture of Experts (MoE) architecture using Liger Kernel optimizations. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary. Indices can be obtained using tokenizers. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + past_key_values (`List[torch.FloatTensor]` or `Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up + sequential decoding. See `past_key_values` input for more details. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the router logits of all MoE layers. See `router_logits` under returned tensors + for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + skip_logits (`bool`, *optional*): + Whether to skip logit computation and directly compute loss. If `None`, defaults to `True` during training + when labels are provided (to save memory), and `False` during inference. + + Returns: + `LigerMoeCausalLMOutputWithPast`: An output object containing: + - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction), including the auxiliary load balancing loss. + - aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + Auxiliary load balancing loss for the sparse MoE modules. + - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + Note: logits are `None` during training when `skip_logits=True` to save memory. + - past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed): + Cached key and value projection states for faster sequential decoding. + - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for each layer) of shape + `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer. + - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax. + - router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + Router logits of the MoE layers, useful to compute the auxiliary loss and z_loss. + - token_accuracy (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + Token-level prediction accuracy. 
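+ - predicted_tokens (`torch.LongTensor`, *optional*, returned when `labels` is provided):
+ Predicted token ids computed alongside `token_accuracy` on the fused Liger loss path.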
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, GptOssForCausalLM
+ >>> from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss
+
+ >>> # Apply Liger Kernel patches for optimized performance
+ >>> apply_liger_kernel_to_gpt_oss()
+
+ >>> model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Inference: Forward pass returns logits
+ >>> outputs = model(**inputs)
+ >>> outputs.logits.shape
+ torch.Size([1, 12, 201088])
+
+ >>> # Get next token prediction
+ >>> next_token_logits = outputs.logits[:, -1, :]
+ >>> predicted_token_id = next_token_logits.argmax(dim=-1)
+
+ >>> # Training: Forward pass with labels returns loss
+ >>> labels = inputs.input_ids.clone()
+ >>> outputs = model(**inputs, labels=labels)
+ >>> outputs.loss
+ tensor(2.6454)
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: MoeModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_router_logits=output_router_logits,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]
+
+ shift_labels = kwargs.pop("shift_labels", None)
+ logits = None
+ loss = None
+ token_accuracy = None
+ predicted_tokens = None
+
+ if skip_logits is None:
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
+ result = LigerForCausalLMLoss(
+ hidden_states=kept_hidden_states,
+ lm_head_weight=self.lm_head.weight,
+ labels=labels,
+ shift_labels=shift_labels,
+ hidden_size=self.config.hidden_size,
+ **kwargs,
+ )
+ loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
+ else: # if in inference mode, materialize logits
+ logits = self.lm_head(kept_hidden_states)
+ if labels is not None or shift_labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ shift_labels=shift_labels,
+ vocab_size=self.vocab_size,
+ **kwargs,
+ )
+
+ aux_loss = None
+ if output_router_logits:
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits,
+ self.num_experts,
+ self.num_experts_per_tok,
+ attention_mask,
+ )
+ if labels is not None:
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure it resides on the same device
+
+ return LigerMoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ router_logits=outputs.router_logits,
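+ # the fields below are None unless the fused Liger loss path computed them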
+ token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/hunyuan_v1.py b/src/liger_kernel/transformers/model/hunyuan_v1.py new file mode 100755 index 0000000000000000000000000000000000000000..dd5aa7a21328ee8270789faef9204f5138cfadbb --- /dev/null +++ b/src/liger_kernel/transformers/model/hunyuan_v1.py @@ -0,0 +1,137 @@ +from typing import List +from typing import Optional +from typing import Union + +import torch + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, +) -> LigerCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, HunYuanDenseV1ForCausalLM + + >>> model = HunYuanDenseV1ForCausalLM.from_pretrained("meta-hunyuan_v1_dense/HunYuanDenseV1-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-hunyuan_v1_dense/HunYuanDenseV1-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
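+ >>> # Illustrative training-style call: with `labels` provided the forward
+ >>> # returns a loss, computed by the fused Liger kernel when in training mode.
+ >>> outputs = model(**inputs, labels=inputs.input_ids)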
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + # Compute loss + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return custom output class with accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/internvl.py b/src/liger_kernel/transformers/model/internvl.py new file mode 100755 index 0000000000000000000000000000000000000000..d9c5aa365179461e0227c8a14dba8d249d29a817 --- /dev/null +++ b/src/liger_kernel/transformers/model/internvl.py @@ -0,0 +1,160 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from transformers.utils import can_return_tuple + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerInternVLCausalLMOutputWithPast + + +# Copied from 
https://github.com/huggingface/transformers/blob/d888bd435d0c0eaabaabad5b33d52af518c7187c/src/transformers/models/internvl/modeling_internvl.py#L862 +@can_return_tuple +def lce_forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: Optional[torch.Tensor] = None, + skip_logits: Optional[bool] = None, # Added argument for liger-kernel + **lm_kwargs, # renamed from kwargs +) -> Union[Tuple, LigerInternVLCausalLMOutputWithPast]: + r""" + Example: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, AutoModelForImageTextToText + + >>> torch_device = "cuda" + >>> processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf") + >>> model = AutoModelForImageTextToText.from_pretrained( + ... "OpenGVLab/InternVL3-1B-hf", dtype=torch.bfloat16, device_map=torch_device + ... ) + + >>> messages = [ + ... { + ... "role": "user", + ... "content": [ + ... { + ... "type": "image", + ... "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg", + ... }, + ... { + ... "type": "image", + ... "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg", + ... }, + ... {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"}, + ... ], + ... }, + ... ] + + >>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device) + >>> generate_ids = model.generate(**inputs, max_new_tokens=200) + >>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)) + The images depict the Statue of Liberty and the Golden Gate Bridge. 
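+
+    >>> # A hedged training-style sketch: with labels provided, skip_logits
+    >>> # defaults to True in training mode and the fused Liger loss path is used
+    >>> labels = inputs["input_ids"].clone()
+    >>> outputs = model(**inputs, labels=labels)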
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + image_sizes=image_sizes, + **lm_kwargs, + ) + + # Copied from llava.py + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = lm_kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.text_config.hidden_size, + **lm_kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **lm_kwargs + ) + + if not return_dict: + output = (logits,) + outputs[1:] + output = (loss,) + output if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerInternVLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/llama.py b/src/liger_kernel/transformers/model/llama.py new file mode 100755 index 0000000000000000000000000000000000000000..9ad3edcb3a135a61268a34b320309f505265ab7a --- /dev/null +++ b/src/liger_kernel/transformers/model/llama.py @@ -0,0 +1,202 @@ +from typing import TYPE_CHECKING +from typing import List +from typing import Optional +from typing import Tuple +from 
typing import Union + +import torch + +from torch.distributed.fsdp import FullyShardedDataParallel + +from liger_kernel.transformers.fsdp import _FSDPForwardRedirection +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast +from liger_kernel.utils import PEFT_AVAILABLE + +if TYPE_CHECKING: + from transformers.cache_utils import Cache + +if PEFT_AVAILABLE: + from peft.utils.other import ModulesToSaveWrapper + + +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, LigerCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
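+
+    >>> # A hedged sketch of forcing the fused path: skip_logits=True requires
+    >>> # labels (or shift_labels), otherwise this forward raises a ValueError
+    >>> labels = inputs.input_ids.clone()
+    >>> outputs = model(**inputs, labels=labels, skip_logits=True)
+    >>> outputs.logits is None
+    True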
+    ```"""
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    if self.config.pretraining_tp > 1:
+        raise Exception("Liger Kernel does not support pretraining_tp!")
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+    predicted_tokens = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = lce_maybe_trainable_lm_head(
+            self,
+            hidden_states=kept_hidden_states,
+            hidden_size=self.config.hidden_size,
+            labels=labels,
+            shift_labels=shift_labels,
+            **kwargs,
+        )
+        loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        output = output + (predicted_tokens,) if predicted_tokens is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
+        predicted_tokens=predicted_tokens,
+    )
+
+
+def lce_maybe_trainable_lm_head(self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
+    lm_head = self.lm_head
+
+    # Unwrap the module if lm_head has been added as a trainable module in the PEFT LoRA
+    # configuration, i.e. listed in the modules_to_save field of LoraConfig, so the lm_head
+    # weights are read from the unwrapped module.
+    # See https://huggingface.co/docs/peft/package_reference/lora for reference.
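+    # Rough shape of the wrapper handled below (illustrative only; exact PEFT
+    # internals may differ): lm_head.original_module holds the frozen base linear
+    # layer, while lm_head.modules_to_save.default is the trainable copy whose
+    # weight the kernel should read.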
+    if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):
+        lm_head = lm_head.modules_to_save.default
+
+    # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA,
+    # reading the lm_head module weights and calling the kernel must be done within the FSDP
+    # forward pass so the module's entire parameters are summoned and kept in memory during
+    # kernel execution.
+    if isinstance(lm_head, FullyShardedDataParallel):
+        return _FSDPForwardRedirection()(
+            lm_head,
+            _liger_for_causal_lm_loss,
+            lm_head.module,
+            hidden_states,
+            hidden_size,
+            labels,
+            shift_labels,
+            **loss_kwargs,
+        )
+
+    # FSDP is not used, so we can read the (possibly unwrapped) lm_head weights and call
+    # the kernel directly
+    return _liger_for_causal_lm_loss(
+        lm_head=lm_head,
+        hidden_states=hidden_states,
+        hidden_size=hidden_size,
+        labels=labels,
+        shift_labels=shift_labels,
+        **loss_kwargs,
+    )
+
+
+def _liger_for_causal_lm_loss(lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
+    return LigerForCausalLMLoss(
+        hidden_states=hidden_states,
+        lm_head_weight=lm_head.weight,
+        labels=labels,
+        hidden_size=hidden_size,
+        shift_labels=shift_labels,
+        **loss_kwargs,
+    )
diff --git a/src/liger_kernel/transformers/model/llama4.py b/src/liger_kernel/transformers/model/llama4.py
new file mode 100755
index 0000000000000000000000000000000000000000..32d4986a94aa987451c2162da622adc8fa0008cb
--- /dev/null
+++ b/src/liger_kernel/transformers/model/llama4.py
@@ -0,0 +1,124 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from transformers.cache_utils import Cache
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
+    r"""
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Llama4ForCausalLM
+
+    >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
+    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
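+    >>> # (The checkpoint ids above follow the upstream docstring template and may
+    >>> # not correspond to a released repository)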
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + # Compute loss + if self.training and (labels is not None or shift_labels is not None): + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: # if in inference mode materialize logits + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/llava.py b/src/liger_kernel/transformers/model/llava.py new file mode 100755 index 0000000000000000000000000000000000000000..1af92165e8b84472ac0f4ef9d4b512edcd84a4a8 --- /dev/null +++ b/src/liger_kernel/transformers/model/llava.py @@ -0,0 +1,160 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerLlavaCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: 
torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: torch.Tensor = None, + skip_logits: Optional[bool] = None, + **lm_kwargs, +) -> Union[Tuple, LigerLlavaCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, LlavaForConditionalGeneration + + >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") + >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") + + >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "USER: \nWhat's the content of the image? 
ASSISTANT: The image features a busy city street with a stop sign prominently displayed" + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + image_sizes=image_sizes, + **lm_kwargs, + ) + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = lm_kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.text_config.hidden_size, + **lm_kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.text_config.vocab_size, + **lm_kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + output = (loss,) + output if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerLlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/loss_utils.py b/src/liger_kernel/transformers/model/loss_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..508b3583d11d479df68dcec8e6ad82463b02a21e --- /dev/null +++ b/src/liger_kernel/transformers/model/loss_utils.py @@ -0,0 
+1,106 @@ +import inspect + +from typing import Optional +from typing import Tuple + +import torch +import torch.nn as nn + +import liger_kernel.transformers.functional as F + +from liger_kernel.transformers.functional import CrossEntropyOutput + + +def unpack_cross_entropy_result( + result, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + if isinstance(result, CrossEntropyOutput): + return result.loss, result.z_loss, result.token_accuracy, result.predicted_tokens + + if isinstance(result, tuple): + loss = result[0] + z_loss = result[1] if len(result) > 1 else None + token_accuracy = result[2] if len(result) > 2 else None + predicted_tokens = result[3] if len(result) > 3 else None + return loss, z_loss, token_accuracy, predicted_tokens + + return result, None, None, None + + +def fixed_fused_linear_cross_entropy( + hidden_states: torch.Tensor, + lm_head_weight: torch.Tensor, + target: torch.Tensor, + num_items_in_batch: Optional[int] = None, + ignore_index: int = -100, + final_logit_softcapping: Optional[float] = None, + accum_dtype: Optional[torch.dtype] = None, + return_token_accuracy: bool = False, + return_predicted_tokens: bool = False, + **kwargs, +): + reduction = "sum" if num_items_in_batch is not None else "mean" + result = F.liger_fused_linear_cross_entropy( + hidden_states, + lm_head_weight, + target, + reduction=reduction, + ignore_index=ignore_index, + softcap=final_logit_softcapping, + accum_dtype=accum_dtype, + return_token_accuracy=return_token_accuracy, + return_predicted_tokens=return_predicted_tokens, + **kwargs, + ) + + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + if reduction == "sum": + loss = loss / num_items_in_batch + + if return_token_accuracy or return_predicted_tokens: + return CrossEntropyOutput(loss=loss, token_accuracy=token_accuracy, predicted_tokens=predicted_tokens) + + return loss + + +def LigerForCausalLMLoss( + hidden_states, + lm_head_weight, + labels, + hidden_size: int, + num_items_in_batch: Optional[int] = None, + ignore_index: int = -100, + shift_labels: Optional[torch.Tensor] = None, + final_logit_softcapping: Optional[float] = None, + return_token_accuracy: bool = False, + return_predicted_tokens: bool = False, + **kwargs, +): + # Filter out inapplicable kwargs to liger_fused_linear_cross_entropy + applicable_params = inspect.signature(F.liger_fused_linear_cross_entropy).parameters + kwargs = {k: v for k, v in kwargs.items() if k in applicable_params} + + # Skip upcast since intermediate values for the loss are all fp32 in kernel + if shift_labels is None: + # Shift so that token < n predict n + labels = nn.functional.pad(labels, (0, 1), value=ignore_index) + shift_labels = labels[..., 1:].contiguous() + + # Flatten the tokens + hidden_states = hidden_states.view(-1, hidden_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(hidden_states.device) + result = fixed_fused_linear_cross_entropy( + hidden_states, + lm_head_weight, + shift_labels, + num_items_in_batch, + ignore_index, + final_logit_softcapping, + return_token_accuracy=return_token_accuracy, + return_predicted_tokens=return_predicted_tokens, + **kwargs, + ) + return result diff --git a/src/liger_kernel/transformers/model/mistral.py b/src/liger_kernel/transformers/model/mistral.py new file mode 100755 index 0000000000000000000000000000000000000000..09efebf5fa14140be230a0ed67e3eb89da78a637 --- /dev/null +++ 
b/src/liger_kernel/transformers/model/mistral.py @@ -0,0 +1,146 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from transformers.cache_utils import Cache + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, LigerCausalLMOutputWithPast]: + r""" + Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy + + + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
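+
+    >>> # A hedged training-mode sketch: with labels, logits are never
+    >>> # materialized and outputs.logits stays None
+    >>> _ = model.train()
+    >>> outputs = model(**inputs, labels=inputs.input_ids.clone())
+    >>> outputs.logits is None
+    True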
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + loss = None + logits = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + skip_logits = self.training and (labels is not None or shift_labels is not None) + + # Compute loss + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + + loss = None + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output_tuple = (logits,) + outputs[1:] + output = (loss,) + output_tuple if loss is not None else output_tuple + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/mixtral.py b/src/liger_kernel/transformers/model/mixtral.py new file mode 100755 index 0000000000000000000000000000000000000000..5c87746bd64af2bc627394d5177af7d53126c818 --- /dev/null +++ b/src/liger_kernel/transformers/model/mixtral.py @@ -0,0 +1,167 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast + + +# Ignore copy +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: 
Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, LigerMoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
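+
+    >>> # A hedged sketch: requesting router logits also computes the MoE
+    >>> # load-balancing auxiliary loss, added to the main loss scaled by
+    >>> # router_aux_loss_coef
+    >>> outputs = model(**inputs, labels=inputs.input_ids.clone(), output_router_logits=True)
+    >>> outputs.aux_loss is not None
+    True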
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + # Compute loss + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + + loss = None + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.vocab_size, + **kwargs, + ) + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output_tuple = (logits,) + outputs[1:] + if output_router_logits: + output_tuple = (aux_loss,) + output_tuple + if token_accuracy is not None: + output_tuple = output_tuple + (token_accuracy,) + if predicted_tokens is not None: + output_tuple = output_tuple + (predicted_tokens,) + return (loss,) + output_tuple if loss is not None else output_tuple + + # Return custom output class with token_accuracy field + return LigerMoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits if return_dict else outputs[-1], + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/mllama.py b/src/liger_kernel/transformers/model/mllama.py new file mode 100755 index 
0000000000000000000000000000000000000000..72094f77a7633ef78cbfd3c00a7e446479c3b4dd --- /dev/null +++ b/src/liger_kernel/transformers/model/mllama.py @@ -0,0 +1,149 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from transformers.cache_utils import Cache + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.LongTensor] = None, + cross_attention_mask: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, LigerCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MllamaForCausalLM + + >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision") + >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") + + >>> prompt = "If I had to write a haiku, it would be:" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) + >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + >>> print(result) + If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. 
+ I love the idea of snowflakes gently falling, each one + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + cross_attention_states=cross_attention_states, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + output = (loss,) + output if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/olmo2.py b/src/liger_kernel/transformers/model/olmo2.py new file mode 100755 index 0000000000000000000000000000000000000000..e78d7815af7d3673e6faebd0b57a17cc4c3d696f --- /dev/null +++ b/src/liger_kernel/transformers/model/olmo2.py @@ -0,0 +1,141 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import 
LigerCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, LigerCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Olmo2ForCausalLM + + >>> model = Olmo2ForCausalLM.from_pretrained("allenai/Olmo2-1B-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo2-1B-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'Hey, are you conscious? 
Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m' + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + # Compute loss + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/olmo3.py b/src/liger_kernel/transformers/model/olmo3.py new file mode 100755 index 0000000000000000000000000000000000000000..e9d1b54a8c252748acd620ded45684cc16976d1e --- /dev/null +++ b/src/liger_kernel/transformers/model/olmo3.py @@ -0,0 +1,143 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from transformers.modeling_outputs import BaseModelOutputWithPast + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast + + +def lce_forward( + 
self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, LigerCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Olmo3ForCausalLM + + >>> model = Olmo3ForCausalLM.from_pretrained("allenai/Olmo-3-7B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo-3-7B-Instruct") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'Hey, are you conscious? 
Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m' + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + # Compute loss + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/output_classes.py b/src/liger_kernel/transformers/model/output_classes.py new file mode 100755 index 0000000000000000000000000000000000000000..f6b768c5065ba1a97a6c384500d39b3afd395e15 --- /dev/null +++ b/src/liger_kernel/transformers/model/output_classes.py @@ -0,0 +1,173 @@ +""" +Custom output classes for Liger-Kernel that extend transformers' ModelOutput classes +with optional token accuracy field. 
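+
+Each `Liger*` class subclasses the matching transformers output class and adds two optional
+fields, `token_accuracy` and `predicted_tokens`, so code written against the base classes
+keeps working. A minimal usage sketch (assuming `model` is a Liger-patched causal LM and
+`ids` is a batch of token ids; both names are illustrative):
+
+    >>> out = model(input_ids=ids, labels=ids)
+    >>> out.loss            # standard CausalLMOutputWithPast field
+    >>> out.token_accuracy  # Liger extra; None unless the fused loss path computed it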
+""" + +from dataclasses import dataclass +from typing import Optional + +import torch + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.modeling_outputs import MoeCausalLMOutputWithPast + +# The following model-specific outputs are optional and depend on the installed +# transformers version. Guard their imports so our module remains importable +# even when those models are not available in the environment. +try: + from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast as _Gemma3CausalLMOutputWithPast +except Exception: + _Gemma3CausalLMOutputWithPast = None + +try: + from transformers.models.glm4v_moe.modeling_glm4v_moe import ( + Glm4vMoeCausalLMOutputWithPast as _Glm4vMoeCausalLMOutputWithPast, + ) +except Exception: + _Glm4vMoeCausalLMOutputWithPast = None + +try: + from transformers.models.internvl.modeling_internvl import ( + InternVLCausalLMOutputWithPast as _InternVLCausalLMOutputWithPast, + ) +except Exception: + _InternVLCausalLMOutputWithPast = None + +try: + from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast as _LlavaCausalLMOutputWithPast +except Exception: + _LlavaCausalLMOutputWithPast = None + +try: + from transformers.models.paligemma.modeling_paligemma import ( + PaliGemmaCausalLMOutputWithPast as _PaliGemmaCausalLMOutputWithPast, + ) +except Exception: + _PaliGemmaCausalLMOutputWithPast = None + +try: + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLCausalLMOutputWithPast as _Qwen2_5_VLCausalLMOutputWithPast, + ) +except Exception: + _Qwen2_5_VLCausalLMOutputWithPast = None + +try: + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLCausalLMOutputWithPast as _Qwen2VLCausalLMOutputWithPast, + ) +except Exception: + _Qwen2VLCausalLMOutputWithPast = None + +try: + from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLCausalLMOutputWithPast as _Qwen3VLCausalLMOutputWithPast, + ) +except Exception: + _Qwen3VLCausalLMOutputWithPast = None + +try: + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + Qwen3VLMoeCausalLMOutputWithPast as _Qwen3VLMoeCausalLMOutputWithPast, + ) +except Exception: + _Qwen3VLMoeCausalLMOutputWithPast = None + +try: + from transformers.models.qwen3_5.modeling_qwen3_5 import ( + Qwen3_5CausalLMOutputWithPast as _Qwen3_5CausalLMOutputWithPast, + ) +except Exception: + _Qwen3_5CausalLMOutputWithPast = None + + +@dataclass +class LigerCausalLMOutputWithPast(CausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + predicted_tokens: Optional[torch.LongTensor] = None + + +@dataclass +class LigerMoeCausalLMOutputWithPast(MoeCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + predicted_tokens: Optional[torch.LongTensor] = None + + +if _Gemma3CausalLMOutputWithPast is not None: + + @dataclass + class LigerGemma3CausalLMOutputWithPast(_Gemma3CausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + predicted_tokens: Optional[torch.LongTensor] = None + + +if _Glm4vMoeCausalLMOutputWithPast is not None: + + @dataclass + class LigerGlm4vMoeCausalLMOutputWithPast(_Glm4vMoeCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + predicted_tokens: Optional[torch.LongTensor] = None + + +if _LlavaCausalLMOutputWithPast is not None: + + @dataclass + class LigerLlavaCausalLMOutputWithPast(_LlavaCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + predicted_tokens: 
Optional[torch.LongTensor] = None + + +if _InternVLCausalLMOutputWithPast is not None: + + @dataclass + class LigerInternVLCausalLMOutputWithPast(_InternVLCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + predicted_tokens: Optional[torch.LongTensor] = None + + +if _PaliGemmaCausalLMOutputWithPast is not None: + + @dataclass + class LigerPaliGemmaCausalLMOutputWithPast(_PaliGemmaCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + predicted_tokens: Optional[torch.LongTensor] = None + + +if _Qwen2_5_VLCausalLMOutputWithPast is not None: + + @dataclass + class LigerQwen2_5_VLCausalLMOutputWithPast(_Qwen2_5_VLCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + predicted_tokens: Optional[torch.LongTensor] = None + + +if _Qwen2VLCausalLMOutputWithPast is not None: + + @dataclass + class LigerQwen2VLCausalLMOutputWithPast(_Qwen2VLCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + predicted_tokens: Optional[torch.LongTensor] = None + + +if _Qwen3VLCausalLMOutputWithPast is not None: + + @dataclass + class LigerQwen3VLCausalLMOutputWithPast(_Qwen3VLCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + predicted_tokens: Optional[torch.LongTensor] = None + + +if _Qwen3VLMoeCausalLMOutputWithPast is not None: + + @dataclass + class LigerQwen3VLMoeCausalLMOutputWithPast(_Qwen3VLMoeCausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + predicted_tokens: Optional[torch.LongTensor] = None + + +if _Qwen3_5CausalLMOutputWithPast is not None: + + @dataclass + class LigerQwen3_5CausalLMOutputWithPast(_Qwen3_5CausalLMOutputWithPast): + token_accuracy: Optional[torch.FloatTensor] = None + predicted_tokens: Optional[torch.LongTensor] = None diff --git a/src/liger_kernel/transformers/model/paligemma.py b/src/liger_kernel/transformers/model/paligemma.py new file mode 100755 index 0000000000000000000000000000000000000000..235635771a69aa6c34be15d5f7dfc5165917777f --- /dev/null +++ b/src/liger_kernel/transformers/model/paligemma.py @@ -0,0 +1,250 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from torch.nn import CrossEntropyLoss +from transformers.cache_utils import Cache +from transformers.utils import is_torchdynamo_compiling +from transformers.utils import logging + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerPaliGemmaCausalLMOutputWithPast + +logger = logging.get_logger(__name__) + + +def lce_forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **lm_kwargs, +) -> Union[Tuple, LigerPaliGemmaCausalLMOutputWithPast]: + r""" + Args: 
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration + + >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf") + >>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf") + + >>> prompt = "answer en Where is the cow standing?" + >>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "answer en Where is the cow standing?\nbeach" + ```""" + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + is_training = token_type_ids is not None and labels is not None + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index) + raise ValueError( + 
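+                    # the expanded placeholder mask must account for exactly as many elements
+                    # as the projected image features about to be scattered into the embeddings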
f"Number of images does not match number of special image tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " + "tokens from image embeddings." + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + # mask out pad-token-ids in labels for BC + if labels is not None and self.pad_token_id in labels: + logger.warning_once( + "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. " + "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", + ) + labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) + + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training + ) + + outputs = self.language_model.model( + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + + shift_labels = lm_kwargs.pop("shift_labels", None) + hidden_states = outputs[0] + + loss = None + logits = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None: + raise ValueError("skip_logits is True, but labels is None") + + if skip_logits is None: + skip_logits = self.training and (labels is not None) + + if skip_logits: + shift_hidden_states = hidden_states[..., :-1, :] + shift_labels = labels[..., 1:] + + hidden_device = shift_hidden_states.device + + if attention_mask is not None: + # we use the input attention mask to shift the hidden_states and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_hidden_states.shape[1] :].to(hidden_device) + shift_hidden_states = shift_hidden_states[shift_attention_mask.to(hidden_device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_hidden_states = shift_hidden_states.contiguous() + shift_labels = shift_labels.contiguous() + + # Flatten hidden state + shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size) + shift_labels = shift_labels.view(-1).to(hidden_device) + + # Use LigerForCausalLMLoss with accuracy support and pass already shifted labels + result = LigerForCausalLMLoss( + hidden_states=shift_hidden_states, + lm_head_weight=self.language_model.lm_head.weight, + labels=None, + shift_labels=shift_labels, + hidden_size=self.config.text_config.hidden_size, + **lm_kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + else: + logits = self.language_model.lm_head(hidden_states) + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + elif shift_labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + output = (loss,) + output if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return PaliGemma output with token_accuracy field + return LigerPaliGemmaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/phi3.py b/src/liger_kernel/transformers/model/phi3.py new file mode 100755 index 0000000000000000000000000000000000000000..b3a9fa1f1a7aed6be543cfe2d24891f5cb5221ee --- /dev/null +++ b/src/liger_kernel/transformers/model/phi3.py @@ -0,0 +1,123 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from transformers.modeling_outputs import BaseModelOutputWithPast + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: 
Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
+    r"""
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Phi3ForCausalLM
+
+    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    outputs: BaseModelOutputWithPast = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs.last_hidden_state
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+    predicted_tokens = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+    if not return_dict:
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        output = output + (predicted_tokens,) if predicted_tokens is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
+        predicted_tokens=predicted_tokens,
+    )
diff --git a/src/liger_kernel/transformers/model/pixtral.py
b/src/liger_kernel/transformers/model/pixtral.py
new file mode 100755
index 0000000000000000000000000000000000000000..c8b3b7b69d7159a6f928e083376a4c40a82314c2
--- /dev/null
+++ b/src/liger_kernel/transformers/model/pixtral.py
@@ -0,0 +1,4 @@
+# Pixtral vision encoder does not require a custom forward function.
+# The Liger kernel optimizations for Pixtral (RMSNorm, SwiGLU, RoPE) are applied
+# via class/function-level monkey patching in monkey_patch.py, which is sufficient
+# since the vision encoder has no cross-entropy loss to fuse.
diff --git a/src/liger_kernel/transformers/model/qwen2.py b/src/liger_kernel/transformers/model/qwen2.py
new file mode 100755
index 0000000000000000000000000000000000000000..6d43caadaa5856de4558bd2a49a4ab1e55faf8cd
--- /dev/null
+++ b/src/liger_kernel/transformers/model/qwen2.py
@@ -0,0 +1,260 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+def lce_forward_deprecated(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Copy-paste of Qwen2's forward, with torch cross entropy replaced by Liger fused linear cross entropy.
+
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
+
+    >>> model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2-1.5B")
+    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
+    if skip_logits:
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+
+    else:
+        logits = self.lm_head(hidden_states)
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + # Compute loss + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output_tuple = (logits,) + outputs[1:] + output = (loss,) + output_tuple if loss is not None else output_tuple + output = output + (token_accuracy,) if token_accuracy is not None else output 
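+        # the extra Liger fields are appended after the standard tuple entries so that
+        # positional unpacking of (loss, logits, past_key_values, ...) keeps working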
+ output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return custom output class with token accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/qwen2_5_vl.py b/src/liger_kernel/transformers/model/qwen2_5_vl.py new file mode 100755 index 0000000000000000000000000000000000000000..ac4aae51cfcb799fdec89bf1a76413c04ce33123 --- /dev/null +++ b/src/liger_kernel/transformers/model/qwen2_5_vl.py @@ -0,0 +1,186 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from packaging import version +from transformers import __version__ as transformers_version +from transformers.utils import can_return_tuple + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerQwen2_5_VLCausalLMOutputWithPast + +_TRANSFORMERS_V5_OR_LATER = version.parse(transformers_version) >= version.parse("5.0.0") + + +def _get_hidden_size(config) -> int: + """Get hidden_size from Qwen2.5VLConfig in a version-aware manner.""" + if _TRANSFORMERS_V5_OR_LATER: + return config.text_config.hidden_size + return config.hidden_size + + +def _get_vocab_size(config) -> int: + """Get vocab_size from Qwen2.5VLConfig in a version-aware manner.""" + if _TRANSFORMERS_V5_OR_LATER: + return config.text_config.vocab_size + return config.vocab_size + + +@can_return_tuple +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + mm_token_type_ids: Optional[torch.IntTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, LigerQwen2_5_VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): + The tensors corresponding to the input videos. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. 
[`Qwen2_5_VLProcessor`] uses
+        [`Qwen2_5_VLImageProcessor`] for processing videos.
+    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        The temporal, height and width of feature shape of each image in LLM.
+    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+        The temporal, height and width of feature shape of each video in LLM.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.
+    second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+        The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+
+    Example:
+
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+    >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
+    >>> messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image"},
+                {"type": "text", "text": "What is shown in this image?"},
+            ],
+        },
+    ]
+    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+
+    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(**inputs, max_length=30)
+    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mm_token_type_ids=mm_token_type_ids, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + shift_labels = kwargs.pop("shift_labels", None) + loss = None + logits = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + skip_logits = self.training and (labels is not None or shift_labels is not None) + + # Compute loss + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=_get_hidden_size(self.config), + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + else: + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=_get_vocab_size(self.config), + ) + + if not return_dict: + output_tuple = (logits,) + outputs[1:] + output = (loss,) + output_tuple if loss is not None else output_tuple + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return Qwen2.5-VL output with token accuracy + return LigerQwen2_5_VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/qwen2_vl.py b/src/liger_kernel/transformers/model/qwen2_vl.py new file mode 100755 index 0000000000000000000000000000000000000000..b51600a2eddedc26d7d41afce628cc34a066f2be --- /dev/null +++ b/src/liger_kernel/transformers/model/qwen2_vl.py @@ -0,0 +1,182 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from packaging import version +from transformers import __version__ as transformers_version +from transformers.utils import can_return_tuple + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerQwen2VLCausalLMOutputWithPast + +_TRANSFORMERS_V5_OR_LATER = version.parse(transformers_version) >= version.parse("5.0.0") + + +def _get_hidden_size(config) -> int: + 
"""Get hidden_size from Qwen2VLConfig in a version-aware manner.""" + if _TRANSFORMERS_V5_OR_LATER: + return config.text_config.hidden_size + return config.hidden_size + + +def _get_vocab_size(config) -> int: + """Get vocab_size from Qwen2VLConfig in a version-aware manner.""" + if _TRANSFORMERS_V5_OR_LATER: + return config.text_config.vocab_size + return config.vocab_size + + +@can_return_tuple +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + mm_token_type_ids: Optional[torch.IntTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, LigerQwen2VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): + The tensors corresponding to the input videos. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses + [`Qwen2VLImageProcessor`] for processing videos. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
+
+    Example:
+
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
+
+    >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+
+    >>> messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image"},
+                {"type": "text", "text": "What is shown in this image?"},
+            ],
+        },
+    ]
+    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+
+    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(**inputs, max_length=30)
+    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+    ```"""
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    outputs = self.model(
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        mm_token_type_ids=mm_token_type_ids,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    loss = None
+    logits = None
+    token_accuracy = None
+    predicted_tokens = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=_get_hidden_size(self.config),
+            **kwargs,
+        )
+        loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
+    else:
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=_get_vocab_size(self.config),
+            )
+
+    if not return_dict:
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        output = output + (predicted_tokens,) if predicted_tokens is not None else output
+        return output
+
+    # Return Qwen2VL output with token accuracy
+    return LigerQwen2VLCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+
attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/qwen3.py b/src/liger_kernel/transformers/model/qwen3.py new file mode 100755 index 0000000000000000000000000000000000000000..5b64fb90ccaa933549360662aba9070beae33156 --- /dev/null +++ b/src/liger_kernel/transformers/model/qwen3.py @@ -0,0 +1,139 @@ +from typing import List +from typing import Optional +from typing import Union + +import torch + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, +) -> LigerCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3ForCausalLM + + >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + # Remove output-control parameters that shouldn't be passed to loss functions + kwargs.pop("return_dict", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + # Compute loss + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return custom output class with accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/qwen3_5.py b/src/liger_kernel/transformers/model/qwen3_5.py new file mode 100755 index 0000000000000000000000000000000000000000..b94304b653c0429aba3d3c020b13b896496b5466 --- /dev/null +++ b/src/liger_kernel/transformers/model/qwen3_5.py @@ -0,0 +1,256 @@ +from typing import List +from typing import Optional +from typing import Union + +import torch + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast +from liger_kernel.transformers.model.output_classes import LigerQwen3_5CausalLMOutputWithPast + + +def 
lce_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> LigerCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3_5ForCausalLM + + >>> model = Qwen3_5ForCausalLM.from_pretrained("Qwen/Qwen3.5-9B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-9B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
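+
+    >>> # Hedged sketch: the liger-specific `skip_logits` flag can force logit
+    >>> # materialization during training (e.g. for debugging):
+    >>> _ = model.train()
+    >>> out = model(**inputs, labels=inputs.input_ids, skip_logits=False)
+    >>> out.logits is not None
+    True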
+ ```""" + return_dict = kwargs.pop("return_dict", None) + if return_dict is None: + return_dict = self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits is None: + skip_logits = self.training and (labels is not None or shift_labels is not None) + + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) + + +def lce_forward_for_multimodal( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + mm_token_type_ids: Optional[torch.IntTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[tuple, LigerQwen3_5CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. 
+ + Example: + + ```python + >>> from transformers import AutoProcessor, Qwen3_5ForConditionalGeneration + + >>> model = Qwen3_5ForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-8B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg", + }, + {"type": "text", "text": "Describe the image."}, + ], + } + ] + + >>> inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt" + ) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=1024) + >>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] + >>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + >>> print(output_text) + ``` + """ + return_dict = kwargs.pop("return_dict", None) + if return_dict is None: + return_dict = self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + mm_token_type_ids=mm_token_type_ids, + **kwargs, + ) + + hidden_states = outputs[0] + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + if skip_logits is None: + skip_logits = self.training and (labels is not None or shift_labels is not None) + + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.text_config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.text_config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + return LigerQwen3_5CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/qwen3_5_moe.py b/src/liger_kernel/transformers/model/qwen3_5_moe.py new file mode 100755 index 0000000000000000000000000000000000000000..e93a3d02ad4774665222a3d95ad97638dc10a51a --- /dev/null +++ b/src/liger_kernel/transformers/model/qwen3_5_moe.py @@ -0,0 +1,157 @@ +from typing import TYPE_CHECKING +from typing import List +from 
typing import Optional +from typing import Union + +import torch + +from transformers.modeling_outputs import MoeModelOutputWithPast + +if TYPE_CHECKING: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import load_balancing_loss_func + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + mm_token_type_ids: Optional[torch.IntTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, +) -> LigerMoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoModelForCausalLM, AutoTokenizer + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-35B-A3B-Instruct") + + >>> prompt = "Give me a short introduction to large language model." + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
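+
+    >>> # Hedged sketch: with `output_router_logits=True` and labels set, the scaled
+    >>> # MoE load-balancing loss is folded into `out.loss` and surfaced as `out.aux_loss`:
+    >>> _ = model.train()
+    >>> out = model(**inputs, labels=inputs.input_ids, output_router_logits=True)
+    >>> out.aux_loss is not None
+    True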
+    ```"""
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_router_logits = (
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
+    )
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs: MoeModelOutputWithPast = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        mm_token_type_ids=mm_token_type_ids,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs.last_hidden_state
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+    predicted_tokens = None
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
+    else:  # in inference mode, materialize logits
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.vocab_size,
+                **kwargs,
+            )
+
+    aux_loss = None
+    if output_router_logits:
+        # The top-level import of `load_balancing_loss_func` is guarded by
+        # TYPE_CHECKING, so import it here at call time to avoid a NameError.
+        from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import load_balancing_loss_func
+
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits,
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((aux_loss,) + output) if aux_loss is not None else output
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        output = output + (predicted_tokens,) if predicted_tokens is not None else output
+        return output
+
+    return LigerMoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+        token_accuracy=token_accuracy,
+        predicted_tokens=predicted_tokens,
+    )
diff --git a/src/liger_kernel/transformers/model/qwen3_moe.py b/src/liger_kernel/transformers/model/qwen3_moe.py
new file mode 100755
index 0000000000000000000000000000000000000000..cee0c9ad3d5e91092cb8db5131459e39048eff71
--- /dev/null
+++ b/src/liger_kernel/transformers/model/qwen3_moe.py
@@ -0,0 +1,155 @@
+from typing import List
+from typing import Optional
+from typing import Union
+
+import torch
+
+from
transformers.modeling_outputs import MoeModelOutputWithPast +from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, +) -> LigerMoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM + + >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
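+
+    >>> # Hedged sketch: during inference, `logits_to_keep=1` computes logits only
+    >>> # for the final position, bounding memory for long prompts:
+    >>> _ = model.eval()
+    >>> out = model(**inputs, logits_to_keep=1)
+    >>> out.logits.shape[1]
+    1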
+    ```"""
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_router_logits = (
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
+    )
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs: MoeModelOutputWithPast = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs.last_hidden_state
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+    predicted_tokens = None
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
+    else:  # in inference mode, materialize logits
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.vocab_size,
+                **kwargs,
+            )
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits,
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((aux_loss,) + output) if aux_loss is not None else output
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        output = output + (predicted_tokens,) if predicted_tokens is not None else output
+        return output
+
+    # Return custom output class with accuracy field
+    return LigerMoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+        token_accuracy=token_accuracy,
+        predicted_tokens=predicted_tokens,
+    )
diff --git a/src/liger_kernel/transformers/model/qwen3_next.py b/src/liger_kernel/transformers/model/qwen3_next.py
new file mode 100755
index 0000000000000000000000000000000000000000..5f6dd0062769637177d8a4968f0b24fce721a73c
--- /dev/null
+++ b/src/liger_kernel/transformers/model/qwen3_next.py
@@ -0,0 +1,155 @@
+from typing import TYPE_CHECKING
+from typing import List
+from typing import
Optional +from typing import Union + +import torch + +from transformers.modeling_outputs import MoeModelOutputWithPast + +if TYPE_CHECKING: + from transformers.models.qwen3_next.modeling_qwen3_next import load_balancing_loss_func + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, +) -> LigerMoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoModelForCausalLM, AutoTokenizer + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct") + + >>> prompt = "Give me a short introduction to large language model." + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
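+
+    >>> # Hedged sketch: labels that are already shifted by one position can be
+    >>> # supplied via the `shift_labels` kwarg (-100 marks positions with no target):
+    >>> import torch
+    >>> shift_labels = torch.cat(
+    ...     [inputs.input_ids[:, 1:], torch.full_like(inputs.input_ids[:, :1], -100)], dim=-1
+    ... )
+    >>> _ = model.train()
+    >>> out = model(**inputs, shift_labels=shift_labels)
+    >>> out.loss is not None
+    True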
+    ```"""
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_router_logits = (
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
+    )
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs: MoeModelOutputWithPast = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs.last_hidden_state
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+    predicted_tokens = None
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
+    else:  # in inference mode, materialize logits
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.vocab_size,
+                **kwargs,
+            )
+
+    aux_loss = None
+    if output_router_logits:
+        # The top-level import of `load_balancing_loss_func` is guarded by
+        # TYPE_CHECKING, so import it here at call time to avoid a NameError.
+        from transformers.models.qwen3_next.modeling_qwen3_next import load_balancing_loss_func
+
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits,
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((aux_loss,) + output) if aux_loss is not None else output
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        output = output + (predicted_tokens,) if predicted_tokens is not None else output
+        return output
+
+    return LigerMoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+        token_accuracy=token_accuracy,
+        predicted_tokens=predicted_tokens,
+    )
diff --git a/src/liger_kernel/transformers/model/qwen3_vl.py b/src/liger_kernel/transformers/model/qwen3_vl.py
new file mode 100755
index 0000000000000000000000000000000000000000..83738ebff14d7666cb9ddfe2baa7ec3a951fb7b0
--- /dev/null
+++ b/src/liger_kernel/transformers/model/qwen3_vl.py
@@ -0,0 +1,155 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from transformers.utils
import can_return_tuple + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerQwen3VLCausalLMOutputWithPast + + +@can_return_tuple +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + mm_token_type_ids: Optional[torch.IntTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, LigerQwen3VLCausalLMOutputWithPast]: + """ + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): + The tensors corresponding to the input videos. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses + [`Qwen2_5_VLImageProcessor`] for processing videos. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. 
+    Example:
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
+    >>> model = Qwen3VLForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL")
+    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL")
+    >>> messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image"},
+                {"type": "text", "text": "What is shown in this image?"},
+            ],
+        },
+    ]
+    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+    ```"""
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    outputs = self.model(
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        second_per_grid_ts=second_per_grid_ts,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        mm_token_type_ids=mm_token_type_ids,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    loss = None
+    logits = None
+    token_accuracy = None
+    predicted_tokens = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.text_config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
+    else:
+        logits = self.lm_head(hidden_states)
+
+        # Mirror the fused path (and the sibling models) by honoring shift_labels here too
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.text_config.vocab_size,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = (loss,) + output if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        output = output + (predicted_tokens,) if predicted_tokens is not None else output
+        return output
+
+    return LigerQwen3VLCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        rope_deltas=outputs.rope_deltas,
+        token_accuracy=token_accuracy,
+        predicted_tokens=predicted_tokens,
+    )
diff --git
a/src/liger_kernel/transformers/model/qwen3_vl_moe.py b/src/liger_kernel/transformers/model/qwen3_vl_moe.py new file mode 100755 index 0000000000000000000000000000000000000000..8c0c805f68f1baa9c42bb3ddc49df05b78e86e7a --- /dev/null +++ b/src/liger_kernel/transformers/model/qwen3_vl_moe.py @@ -0,0 +1,131 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import load_balancing_loss_func +from transformers.utils import can_return_tuple + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerQwen3VLMoeCausalLMOutputWithPast + + +@can_return_tuple +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + mm_token_type_ids: Optional[torch.IntTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, LigerQwen3VLMoeCausalLMOutputWithPast]: + """ + Qwen3-VL-MoE forward with fused linear cross entropy support mirroring Qwen3-VL behaviour. 
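+
+    Example (a hedged sketch; assumes `model` and `inputs` are prepared with the
+    Qwen3-VL processor as in the `qwen3_vl.lce_forward` docstring above, and uses
+    `input_ids` as labels purely for illustration):
+
+    ```python
+    >>> _ = model.train()
+    >>> out = model(**inputs, labels=inputs["input_ids"], output_router_logits=True)
+    >>> # with labels given, out.loss folds in router_aux_loss_coef * out.aux_loss
+    ```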
+ """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mm_token_type_ids=mm_token_type_ids, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + shift_labels = kwargs.pop("shift_labels", None) + loss = None + logits = None + token_accuracy = None + predicted_tokens = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + skip_logits = self.training and (labels is not None or shift_labels is not None) + + if skip_logits: + result = LigerForCausalLMLoss( + hidden_states=hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.text_config.hidden_size, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + else: + logits = self.lm_head(hidden_states) + + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + # Compute auxiliary load-balancing loss for MoE when requested + aux_loss = None + if kwargs.get("output_router_logits", False): + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.config.text_config.num_experts, + self.config.text_config.num_experts_per_tok, + attention_mask, + ) + # If we computed training loss, add the scaled aux loss to it + if loss is not None and aux_loss is not None: + loss = loss + self.config.text_config.router_aux_loss_coef * aux_loss.to(loss.device) + + if not return_dict: + output = (logits,) + outputs[1:] + output = (loss,) + output if loss is not None else output + output = output + (aux_loss,) if aux_loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + return LigerQwen3VLMoeCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + aux_loss=aux_loss, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) diff --git a/src/liger_kernel/transformers/model/smollm3.py b/src/liger_kernel/transformers/model/smollm3.py new file mode 100755 index 0000000000000000000000000000000000000000..3a9167f5658d7b26f10c76efbe399992dceddaf7 --- /dev/null +++ b/src/liger_kernel/transformers/model/smollm3.py @@ -0,0 +1,200 @@ +from typing import TYPE_CHECKING +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from torch.distributed.fsdp import FullyShardedDataParallel + +from 
liger_kernel.transformers.fsdp import _FSDPForwardRedirection +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast +from liger_kernel.utils import PEFT_AVAILABLE + +if TYPE_CHECKING: + from transformers.cache_utils import Cache + +if PEFT_AVAILABLE: + from peft.utils.other import ModulesToSaveWrapper + + +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, LigerCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Smollm3ForCausalLM + + >>> model = Smollm3ForCausalLM.from_pretrained("HuggingFaceTB/SmolLM3-3B") + >>> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
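+
+    >>> # Hedged note: with PEFT LoRA (e.g. the head listed in `modules_to_save`) or
+    >>> # under FSDP, the fused-loss path still applies; `lce_maybe_trainable_lm_head`
+    >>> # below resolves the actual lm_head weights before the kernel is invoked.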
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + token_accuracy = None + predicted_tokens = None + + # if in training mode, don't materialize logits + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + # Compute loss + if skip_logits: + result = lce_maybe_trainable_lm_head( + self, + hidden_states=kept_hidden_states, + hidden_size=self.config.hidden_size, + labels=labels, + shift_labels=shift_labels, + **kwargs, + ) + loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output_tuple = (logits,) + outputs[1:] + output = (loss,) + output_tuple if loss is not None else output_tuple + output = output + (token_accuracy,) if token_accuracy is not None else output + output = output + (predicted_tokens,) if predicted_tokens is not None else output + return output + + # Return custom output class with token_accuracy field + return LigerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + token_accuracy=token_accuracy, + predicted_tokens=predicted_tokens, + ) + + +def lce_maybe_trainable_lm_head(self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs): + lm_head = self.lm_head + + # Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration, + # i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read + # from the unwrapped module. + # See https://huggingface.co/docs/peft/package_reference/lora for reference. 
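+    # Illustration only (hypothetical config): the wrapper appears when the head is
+    # listed in `modules_to_save`, e.g.
+    #   LoraConfig(r=8, target_modules=["q_proj", "v_proj"], modules_to_save=["lm_head"])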
+ if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper): + lm_head = lm_head.modules_to_save.default + + # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA, + # reading the lm_head module weights and calling the kernel must be done within FSDP forward pass + # so the module entire parameters are summoned and kept in memory during the kernel execution. + if isinstance(lm_head, FullyShardedDataParallel): + return _FSDPForwardRedirection()( + lm_head, + _liger_for_causal_lm_loss, + lm_head.module, + hidden_states, + hidden_size, + labels, + shift_labels, + **loss_kwargs, + ) + + # FSDP is not used so we can read the lm_head weights and call the kernel directly + return _liger_for_causal_lm_loss( + lm_head=self.lm_head, + hidden_states=hidden_states, + hidden_size=hidden_size, + labels=labels, + shift_labels=shift_labels, + **loss_kwargs, + ) + + +def _liger_for_causal_lm_loss(lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs): + return LigerForCausalLMLoss( + hidden_states=hidden_states, + lm_head_weight=lm_head.weight, + labels=labels, + hidden_size=hidden_size, + shift_labels=shift_labels, + **loss_kwargs, + ) diff --git a/src/liger_kernel/transformers/model/smolvlm.py b/src/liger_kernel/transformers/model/smolvlm.py new file mode 100755 index 0000000000000000000000000000000000000000..395c0c95770de3fff463c543c60c955adf8b7d7f --- /dev/null +++ b/src/liger_kernel/transformers/model/smolvlm.py @@ -0,0 +1,158 @@ +from typing import TYPE_CHECKING +from typing import Optional +from typing import Union + +import torch + +from transformers.models.smolvlm.modeling_smolvlm import SmolVLMCausalLMOutputWithPast +from transformers.processing_utils import Unpack +from transformers.utils.generic import can_return_tuple + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss + +if TYPE_CHECKING: + from transformers.cache_utils import Cache + from transformers.utils.generic import TransformersKwargs + + +# Forward adapted to enable fused Linear + CE without materializing logits. +# Mirrors the pattern used for other multimodal models (e.g., InternVL, LLaVA). +@can_return_tuple +def lce_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional["Cache"] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, # Added argument for liger-kernel + **lm_kwargs: Unpack["TransformersKwargs"], # renamed from kwargs +) -> Union[tuple, SmolVLMCausalLMOutputWithPast]: + r""" + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): + Mask to avoid performing attention on padding pixel indices. + image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The hidden states of the image encoder after modality projection. 
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from transformers import AutoProcessor, AutoModelForImageTextToText + >>> from transformers.image_utils import load_image + + >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible + >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg") + >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg") + >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg") + + >>> processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct") + >>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct", dtype=torch.bfloat16, device_map="auto") + + >>> # Create inputs + >>> messages = [ + ... { + ... "role": "user", + ... "content": [ + ... {"type": "video", "path": path/to/video}, + ... {"type": "text", "text": "What is happening in this video?"}, + ... ] + ... } + ... ] + + >>> inputs = processor.apply_chat_template([messages], add_generation_prompt=True) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=256) + >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True) + + >>> print(generated_texts) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + return_dict=True, + **lm_kwargs, + ) + + # Copied from llava.py + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = lm_kwargs.pop("shift_labels", None) + logits = None + loss = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + if skip_logits: + loss = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + 
lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.text_config.hidden_size, + **lm_kwargs, + ) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **lm_kwargs + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return SmolVLMCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py new file mode 100755 index 0000000000000000000000000000000000000000..e5a526003eaafdf131ba9aecf49b88b1f1e857e6 --- /dev/null +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -0,0 +1,3178 @@ +import inspect +import logging + +from functools import partial +from types import MethodType +from typing import Callable +from typing import Optional + +import transformers + +from packaging import version +from transformers import PreTrainedModel + +from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss +from liger_kernel.transformers.functional import liger_cross_entropy +from liger_kernel.transformers.geglu import LigerGEGLUMLP +from liger_kernel.transformers.layer_norm import LigerLayerNorm +from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward +from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward +from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward +from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward +from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward +from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward +from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward +from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward +from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward +from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward +from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward +from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb +from liger_kernel.transformers.rms_norm import LigerRMSNorm +from liger_kernel.transformers.rope import liger_rotary_pos_emb +from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision +from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP +from liger_kernel.transformers.swiglu import LigerExperts +from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP +from liger_kernel.transformers.swiglu import LigerSwiGLUMLP + +try: + import peft + + PEFT_AVAILABLE = True +except ImportError: + PEFT_AVAILABLE = False + +transformer_version = version.parse(transformers.__version__) + +logger = logging.getLogger(__name__) + +MIN_SUPPORTED_TRANSFORMERS_VERSION = version.parse("4.52.0") +if transformer_version < MIN_SUPPORTED_TRANSFORMERS_VERSION: + raise ImportError( + f"liger-kernel requires transformers >= {MIN_SUPPORTED_TRANSFORMERS_VERSION}, got {transformers.__version__}. 
" + "Please install an older version of liger-kernel that is compatible with your transformers version." + ) + +IS_TRANSFORMERS_V5_OR_LATER = version.parse(transformers.__version__) >= version.parse("5.0.0") + + +def _bind_method_to_module(module, method_name: str, new_method: Callable): + # Binds a new method to a module instance so that self is passed as the first argument + module.__dict__[method_name] = new_method.__get__(module, module.__class__) + + +def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None): + # Check if the module is a PEFT ModulesToSaveWrapper + # If it is, we need to patch the modules_to_save.default and original_modules + if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper): + module.modules_to_save.default.offset = offset + module.modules_to_save.default.casting_mode = casting_mode + module.modules_to_save.default.variance_epsilon = ( + getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps + ) + module.modules_to_save.default.in_place = in_place + module.modules_to_save.default.row_mode = row_mode + module.original_module.offset = offset + module.original_module.casting_mode = casting_mode + module.original_module.variance_epsilon = ( + getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps + ) + module.original_module.in_place = in_place + module.original_module.row_mode = row_mode + _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward) + _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr) + _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward) + _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr) + _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__) + _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__) + else: + module.offset = offset + module.casting_mode = casting_mode + module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps + module.in_place = in_place + module.row_mode = row_mode + _bind_method_to_module(module, "forward", LigerRMSNorm.forward) + _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr) + _bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__) + + +def _patch_layer_norm_module(module, eps=1e-6): + # Check if the module is a PEFT ModulesToSaveWrapper + # If it is, we need to patch the modules_to_save.default and original_modules + if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper): + module.hidden_size = module.normalized_shape + _bind_method_to_module(module, "forward", LigerLayerNorm.forward) + _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr) + module.modules_to_save.default.variance_epsilon = ( + getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps + ) + module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr( + module, "normalized_shape", None + ) + module.original_module.variance_epsilon = ( + getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps + ) + module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr( + module, "normalized_shape", None + ) + _bind_method_to_module(module.modules_to_save.default, "forward", 
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
+    else:
+        module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+        _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)
+
+
+def _patch_swiglu_module(module, liger_module):
+    _bind_method_to_module(module, "forward", liger_module.forward)
+    _bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)
+
+
+def _patch_geglu_module(module):
+    _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
+    _bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)
+
+
+def apply_liger_kernel_to_granite(
+    rope: bool = True,
+    cross_entropy: bool = True,
+    fused_linear_cross_entropy: bool = False,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Granite 3 models
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+
+    Debugging notes:
+        One would expect that if LigerSwiGLUMLP works for Llama, it would also work for Granite, but it does not.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.granite import modeling_granite
+    from transformers.models.granite.modeling_granite import GraniteModel
+
+    if swiglu:
+        modeling_granite.GraniteMLP = LigerSwiGLUMLP
+
+    if rms_norm:
+        modeling_granite.GraniteRMSNorm = LigerRMSNorm
+
+    if rope:
+        modeling_granite.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        raise NotImplementedError("LigerFusedLinearCrossEntropy is not available for Granite models.")
+        # NOTE: Granite's `GraniteForCausalLM.forward` scales the logits on each
+        # call, so we can't sidestep logit materialization. A bit more work
+        # would be needed to add a scaling term to the `LigerFusedLinearCrossEntropyFunction`
+        # for the logit output.
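
The note above is easy to verify on paper: Granite divides its logits by a configured scale before the loss, and because a scalar commutes with the matmul, that scale could equivalently be folded into the `lm_head` weight that a fused linear-cross-entropy kernel consumes. A minimal plain-PyTorch sketch of the equivalence (shapes and the `logits_scaling` value are illustrative, not Granite's real config):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden = torch.randn(4, 64)    # last hidden states: (batch, hidden_size)
weight = torch.randn(128, 64)  # lm_head weight: (vocab_size, hidden_size)
labels = torch.randint(0, 128, (4,))
logits_scaling = 8.0           # illustrative; Granite reads this from its config

# Stock Granite forward: materialize logits, scale, then take cross entropy.
loss_ref = F.cross_entropy(hidden @ weight.T / logits_scaling, labels)

# (h @ W.T) / s == h @ (W / s).T, so a fused kernel could consume a pre-scaled
# weight and never materialize the logits.
loss_folded = F.cross_entropy(hidden @ (weight / logits_scaling).T, labels)

assert torch.allclose(loss_ref, loss_folded, atol=1e-5)
```
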
+ + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules (e.g. GraniteRMSNorm or GraniteMLP) + + # get the base model from the model instance + base_model: GraniteModel = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm) + + for decoder_layer in base_model.layers: + if swiglu: + _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP) + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + + +def apply_liger_kernel_to_llama( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3) + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + """ + + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) + + from transformers.models.llama import modeling_llama + from transformers.models.llama.modeling_llama import LlamaModel + + if rope: + modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb + if rms_norm: + modeling_llama.LlamaRMSNorm = LigerRMSNorm + if swiglu: + modeling_llama.LlamaMLP = LigerSwiGLUMLP + + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(llama_lce_forward, model) + else: + modeling_llama.LlamaForCausalLM.forward = llama_lce_forward + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP) + + # get the base model from the model instance + base_model: LlamaModel = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm) + + for decoder_layer in base_model.layers: + if swiglu: + _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP) + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + + +def apply_liger_kernel_to_smollm3( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. 
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.smollm3 import modeling_smollm3
+    from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
+
+    if rope:
+        modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(smollm3_lce_forward, model)
+        else:
+            modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
+
+        # get the base model from the model instance
+        base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_llava(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    model: PreTrainedModel = None,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace LLaVA models.
+    Due to the characteristics of LLaVA, the model instance must be passed in so that Liger-Kernel's patches
+    can also be applied to the models connected to LLaVA. However, if a language model not supported by
+    Liger-Kernel is connected to LLaVA, unexpected side effects may occur.
+    NOTE: LLaVA is not available in transformers<4.36.0
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+        **kwargs: Additional keyword arguments (e.g. `rope`, `rms_norm`, `swiglu`) forwarded to the Liger patch
+            functions of the attached language and vision models; parameters a given patcher does not accept
+            are dropped with a warning.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
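
The body below routes to per-model patchers through the `MODEL_TYPE_TO_APPLY_LIGER_FN` mapping (defined elsewhere in this file). A toy sketch of that registry-dispatch pattern, with made-up names:

```python
from typing import Callable, Dict

def apply_to_llama(rope: bool = True, **_: object) -> str:
    # Hypothetical stand-in for apply_liger_kernel_to_llama.
    return f"patched llama (rope={rope})"

# Maps a config's `model_type` string to the matching patch function.
PATCH_REGISTRY: Dict[str, Callable[..., str]] = {"llama": apply_to_llama}

model_type = "llama"  # in the real code: model.config.text_config.model_type
patch_fn = PATCH_REGISTRY.get(model_type)
if patch_fn is None:
    print(f"{model_type} is not supported by Liger kernel.")
else:
    print(patch_fn(rope=True))  # patched llama (rope=True)
```
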
+
+    from transformers.models.llava import modeling_llava
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(llava_lce_forward, model)
+        else:
+            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+
+    if model is not None:
+        text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
+        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+        vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
+
+        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs}
+        if text_liger_fn:
+            accept_params = inspect.signature(text_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {text_model_name}: {list(remain_params)}. "
+                    f"Applying only the supported parameters {list(text_kwargs.keys())}.\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = model.model.language_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        if vision_liger_fn:
+            accept_params = inspect.signature(vision_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {vision_model_name}: {list(remain_params)}. "
+                    f"Applying only the supported parameters {list(vision_kwargs.keys())}.\n"
+                    f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
+                )
+            vision_kwargs["model"] = model.model.vision_tower
+            vision_liger_fn(**vision_kwargs)
+        elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
+
+
+def apply_liger_kernel_to_llama4(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+    layer_norm: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        layer_norm (bool): Whether to apply Liger's LayerNorm to the vision model. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llama4 import modeling_llama4
+    from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+    from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+    from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+    from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+    from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+    if rope:
+        from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
+
+        apply_liger_llama4_rope_full(modeling_llama4)
+    if rms_norm:
+        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Llama4ForConditionalGeneration):
+            language_model: Llama4ForCausalLM = model.language_model
+            vision_model: Llama4VisionModel = model.vision_model
+            text_model: Llama4TextModel = language_model.model
+        elif isinstance(model, Llama4ForCausalLM):
+            text_model = model.model
+            vision_model = None
+        elif isinstance(model, Llama4TextModel):
+            text_model = model
+            vision_model = None
+        else:
+            raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+        if text_model:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    if decoder_layer.is_moe_layer:
+                        _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
+                    else:
+                        _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+        if vision_model:
+            if layer_norm:
+                _patch_layer_norm_module(vision_model.layernorm_pre)
+                _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.model.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_mllama(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    layer_norm: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace MLlama models.
+    NOTE: MLlama is not available in transformers<4.45.0
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        layer_norm (bool): Whether to apply Liger's LayerNorm to the vision model. Default is True.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
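
The LLaVA patcher above filters the kwargs it forwards against each callee's signature, so that, e.g., a `geglu` flag is never passed to a patcher that only accepts `swiglu`. A standalone sketch of that filtering idiom (the callee here is hypothetical):

```python
import inspect

def apply_to_text_model(rope: bool = True, rms_norm: bool = True, model: object = None) -> dict:
    # Hypothetical patch function; only these parameters are accepted.
    return {"rope": rope, "rms_norm": rms_norm}

kwargs = {"rope": False, "rms_norm": True, "geglu": True}  # "geglu" is not accepted here

accept_params = inspect.signature(apply_to_text_model).parameters
remain_params = set(kwargs) - set(accept_params)
filtered = {k: v for k, v in kwargs.items() if k not in remain_params}

if remain_params:
    print(f"Dropping unsupported parameters: {sorted(remain_params)}")
print(apply_to_text_model(**filtered))  # {'rope': False, 'rms_norm': True}
```
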
+ """ + + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) + + from transformers.models.mllama import modeling_mllama + from transformers.models.mllama.modeling_mllama import MllamaForCausalLM + from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration + from transformers.models.mllama.modeling_mllama import MllamaTextModel + from transformers.models.mllama.modeling_mllama import MllamaVisionModel + + from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward + + if rope: + modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb + if layer_norm and model is None: + modeling_mllama.nn.LayerNorm = LigerLayerNorm + if rms_norm: + modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm + if swiglu: + modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(mllama_lce_forward, model) + else: + modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + if isinstance(model, MllamaForConditionalGeneration): + language_model: MllamaForCausalLM = model.model.language_model + vision_model: MllamaVisionModel = model.model.vision_model + if isinstance(language_model, MllamaForCausalLM): + text_model: MllamaTextModel = language_model.model + else: + text_model = language_model + elif isinstance(model, MllamaForCausalLM): + text_model = model.model + vision_model = None + elif isinstance(model, MllamaTextModel): + text_model = model + vision_model = None + + else: + raise ValueError(f"Unsupported Mllama model type: {type(model)}") + + if text_model: + if rms_norm: + _patch_rms_norm_module(text_model.norm) + for decoder_layer in text_model.layers: + if swiglu: + _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP) + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + + if vision_model: + _patch_layer_norm_module(vision_model.layernorm_pre) + _patch_layer_norm_module(vision_model.layernorm_post) + + for layer in vision_model.transformer.layers: + if layer_norm: + _patch_layer_norm_module(layer.input_layernorm) + _patch_layer_norm_module(layer.post_attention_layernorm) + + for layer in vision_model.global_transformer.layers: + if layer_norm: + _patch_layer_norm_module(layer.input_layernorm) + _patch_layer_norm_module(layer.post_attention_layernorm) + + +def apply_liger_kernel_to_mistral( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Mistral models + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is False. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. 
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.mistral import modeling_mistral
+    from transformers.models.mistral.modeling_mistral import MistralModel
+
+    if rope:
+        modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_mistral.MistralRMSNorm = LigerRMSNorm
+    if cross_entropy:
+        modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(mistral_lce_forward, model)
+        else:
+            modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+
+    if swiglu:
+        modeling_mistral.MistralMLP = LigerSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: MistralModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_mixtral(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Mixtral models
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ ) + + from transformers.models.mixtral import modeling_mixtral + from transformers.models.mixtral.modeling_mixtral import MixtralModel + + if rope: + modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb + if rms_norm: + modeling_mixtral.MixtralRMSNorm = LigerRMSNorm + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(mixtral_lce_forward, model) + else: + modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward + if swiglu: + if IS_TRANSFORMERS_V5_OR_LATER: + modeling_mixtral.MixtralExperts = LigerExperts + else: + modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: MixtralModel = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm) + + for decoder_layer in base_model.layers: + if swiglu: + if IS_TRANSFORMERS_V5_OR_LATER: + _patch_swiglu_module(decoder_layer.mlp.experts, LigerExperts) + else: + for expert in decoder_layer.block_sparse_moe.experts: + _patch_swiglu_module(expert, LigerBlockSparseTop2MLP) + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + + +def apply_liger_kernel_to_pixtral( + rope: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Pixtral vision models. + + Note: Pixtral's vision encoder does not have a cross-entropy loss, so there is no + `fused_linear_cross_entropy` or `cross_entropy` option. The language model side of + Pixtral uses Mistral, which can be patched separately via `apply_liger_kernel_to_mistral`. + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model + has already been loaded. Default is None. + """ + from transformers.models.pixtral import modeling_pixtral + from transformers.models.pixtral.modeling_pixtral import PixtralVisionModel + + if rope: + modeling_pixtral.apply_rotary_pos_emb = liger_rotary_pos_emb + if rms_norm: + modeling_pixtral.PixtralRMSNorm = LigerRMSNorm + if swiglu: + modeling_pixtral.PixtralMLP = LigerSwiGLUMLP + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules. 
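
As in every patcher here, swapping the class attribute on the `modeling_*` module (e.g. `modeling_pixtral.PixtralRMSNorm = LigerRMSNorm`) only affects modules constructed afterwards; instances that already exist keep their original class, which is why this branch rebinds `forward` per instance. A toy sketch of the descriptor binding that `_bind_method_to_module` performs (the classes here are made up):

```python
class RMSNorm:
    def forward(self) -> str:
        return "original forward"

def liger_forward(self) -> str:
    # Stand-in for LigerRMSNorm.forward; `self` is the patched instance.
    return f"liger forward on {type(self).__name__}"

norm = RMSNorm()  # built before any class-level patching

# Assigning the bare function would not pass `self`; binding through __get__
# yields a proper bound method, which is what _bind_method_to_module does.
norm.forward = liger_forward.__get__(norm, RMSNorm)

print(norm.forward())  # liger forward on RMSNorm
```
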
+ if isinstance(model, PixtralVisionModel): + transformer = model.transformer + else: + raise ValueError(f"Unsupported Pixtral model type: {type(model)}") + + if rms_norm: + _patch_rms_norm_module(model.ln_pre, eps=1e-5) + + for layer in transformer.layers: + if swiglu: + _patch_swiglu_module(layer.feed_forward, LigerSwiGLUMLP) + if rms_norm: + _patch_rms_norm_module(layer.attention_norm, eps=1e-5) + _patch_rms_norm_module(layer.ffn_norm, eps=1e-5) + + +def apply_liger_kernel_to_gemma( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + geglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Gemma + (Gemma 1 and 1.1 supported, for Gemma2 please use `apply_liger_kernel_to_gemma2` ) to make GPU go burrr. + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + """ + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) + + from transformers.models.gemma import modeling_gemma + from transformers.models.gemma.modeling_gemma import GemmaModel + + from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma + + _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0) + + if rope: + modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb + if rms_norm: + modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + if geglu: + modeling_gemma.GemmaMLP = LigerGEGLUMLP + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(gemma_lce_forward, model) + else: + modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: GemmaModel = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module_for_gemma(base_model.norm) + + for decoder_layer in base_model.layers: + if geglu: + _patch_geglu_module(decoder_layer.mlp) + if rms_norm: + _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm) + _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm) + + +def apply_liger_kernel_to_gemma2( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + geglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Gemma2 + (for Gemma1 please use `apply_liger_kernel_to_gemma`) to make GPU go burrr. 
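
The Gemma patchers bind `_patch_rms_norm_module` with `offset=1.0` and `casting_mode="gemma"` because Gemma-family RMSNorm scales by `(1 + weight)`, where Llama-style RMSNorm scales by `weight` alone. A plain-PyTorch reference of what the `offset` knob means (this is the math, not the Triton kernel):

```python
import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, offset: float = 0.0) -> torch.Tensor:
    # Normalize by the root mean square of the last dimension, then scale.
    variance = x.pow(2).mean(-1, keepdim=True)
    x_hat = x * torch.rsqrt(variance + eps)
    # offset=0.0 -> Llama-style scale by w; offset=1.0 -> Gemma-style scale by (1 + w).
    return x_hat * (offset + weight)

x = torch.randn(2, 8)
w = torch.zeros(8)  # Gemma initializes RMSNorm weights to zero, so (1 + w) == 1
out = rms_norm_ref(x, w, offset=1.0)
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(out, expected))  # True
```
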
+ + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + """ + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) + + from transformers.models.gemma2 import modeling_gemma2 + from transformers.models.gemma2.modeling_gemma2 import Gemma2Model + + from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2 + + _patch_rms_norm_module_for_gemma2 = partial( + _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False + ) + + if rope: + modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb + if rms_norm: + # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109 + modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2 + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(gemma2_lce_forward, model) + else: + modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward + if geglu: + modeling_gemma2.Gemma2MLP = LigerGEGLUMLP + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: Gemma2Model = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module_for_gemma2(base_model.norm) + + for decoder_layer in base_model.layers: + if geglu: + _patch_geglu_module(decoder_layer.mlp) + if rms_norm: + _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm) + _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm) + _patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm) + _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm) + + +def apply_liger_kernel_to_gemma3_text( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + geglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Gemma3 + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True. 
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.gemma3 import modeling_gemma3
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
+
+    from liger_kernel.transformers.model.gemma3 import causal_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
+
+    _patch_rms_norm_module_for_gemma3 = partial(
+        _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+    )
+
+    if rope:
+        modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_gemma3.Gemma3RMSNorm = LigerRMSNormForGemma3
+
+    if geglu:
+        modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
+
+    # Handle loss function
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(causal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
+            # get the base model from the model instance
+            base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model
+
+            if rms_norm:
+                _patch_rms_norm_module_for_gemma3(base_model.norm)
+
+            for decoder_layer in base_model.layers:
+                decoder_layer: Gemma3DecoderLayer
+                if geglu:
+                    _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+                if rms_norm:
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.post_attention_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.pre_feedforward_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.post_feedforward_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.q_norm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.k_norm)
+
+        else:
+            raise TypeError("The model must be Gemma3ForCausalLM or Gemma3TextModel.")
+
+
+def apply_liger_kernel_to_gemma3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    layer_norm: bool = True,
+    rms_norm: bool = True,
+    geglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Gemma3 (multimodal)
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.gemma3 import modeling_gemma3
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
+    from transformers.models.siglip import modeling_siglip
+    from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
+    from transformers.models.siglip.modeling_siglip import SiglipVisionModel
+
+    from liger_kernel.transformers.model.gemma3 import multimodal_forward
+
+    _patch_rms_norm_module_for_gemma3 = partial(
+        _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+    )
+
+    if layer_norm and model is None:
+        modeling_siglip.nn.LayerNorm = LigerLayerNorm
+
+    apply_liger_kernel_to_gemma3_text(
+        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+    )
+
+    if cross_entropy:
+        modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(multimodal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        if isinstance(model, Gemma3ForConditionalGeneration):
+            if isinstance(model.model.vision_tower, SiglipVisionModel):
+                vision_tower = model.model.vision_tower
+
+                if layer_norm:
+                    _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
+
+                for layer in vision_tower.vision_model.encoder.layers:
+                    layer: SiglipEncoderLayer
+                    if layer_norm:
+                        _patch_layer_norm_module(layer.layer_norm1)
+                        _patch_layer_norm_module(layer.layer_norm2)
+            else:
+                raise TypeError("The vision tower must be SiglipVisionModel")
+
+            if rms_norm:
+                _patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm)
+
+            apply_liger_kernel_to_gemma3_text(
+                rope=rope,
+                cross_entropy=False,
+                fused_linear_cross_entropy=False,
+                rms_norm=rms_norm,
+                geglu=geglu,
+                model=model.model.language_model,
+            )
+
+        else:
+            raise TypeError("The model must be Gemma3ForConditionalGeneration.")
+
+
+def apply_liger_kernel_to_paligemma(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    layer_norm: bool = True,
+    rms_norm: bool = True,
+    geglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace PaliGemma
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
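
As the docstring above notes, every patcher in this file supports two call patterns: patch the `modeling_*` classes before the model is constructed, or pass an already-loaded model so the existing modules are rebound in place. A usage sketch using the Llama patcher (the checkpoint ID is a placeholder):

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama

# Pattern 1: patch first, then construct; the swapped classes (LigerRMSNorm,
# LigerSwiGLUMLP, ...) are picked up when the model is built.
apply_liger_kernel_to_llama()
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

# Pattern 2: the model already exists, so pass it in; the patcher then walks
# the decoder layers and rebinds forward on each pre-built module.
model2 = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
apply_liger_kernel_to_llama(model=model2)
```
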
+ """ + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) + + # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model'] + + from transformers.models.gemma.modeling_gemma import GemmaForCausalLM + from transformers.models.gemma.modeling_gemma import GemmaModel + from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM + from transformers.models.gemma2.modeling_gemma2 import Gemma2Model + from transformers.models.paligemma import modeling_paligemma + from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration + from transformers.models.siglip import modeling_siglip + from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer + from transformers.models.siglip.modeling_siglip import SiglipVisionModel + + from liger_kernel.transformers.model.paligemma import lce_forward + + # The vision_tower is a SiglipVisionModel + if layer_norm and model is None: + modeling_siglip.nn.LayerNorm = LigerLayerNorm + + # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible + # The multi_modal_projector is Linear, nothing to do + + # The language_model is GemmaForCausalLM or Gemma2ForCausalLM + apply_liger_kernel_to_gemma( + rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu + ) + apply_liger_kernel_to_gemma2( + rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu + ) + # Handle loss function + if cross_entropy: + modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(lce_forward, model) + else: + modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + if not isinstance(model, PaliGemmaForConditionalGeneration): + raise TypeError("model have to be of type PaliGemmaForConditionalGeneration") + + vision_tower: SiglipVisionModel = model.model.vision_tower + + _patch_layer_norm_module(vision_tower.vision_model.post_layernorm) + + for layer in vision_tower.vision_model.encoder.layers: + layer: SiglipEncoderLayer + if layer_norm: + _patch_layer_norm_module(layer.layer_norm1) + _patch_layer_norm_module(layer.layer_norm2) + + language_model = model.model.language_model + + if isinstance(language_model, (GemmaForCausalLM, GemmaModel)): + apply_liger_kernel_to_gemma( + rope=rope, + cross_entropy=False, + fused_linear_cross_entropy=False, + rms_norm=rms_norm, + geglu=geglu, + model=language_model, + ) + + elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)): + apply_liger_kernel_to_gemma2( + rope=rope, + cross_entropy=False, + fused_linear_cross_entropy=False, + rms_norm=rms_norm, + geglu=geglu, + model=language_model, + ) + else: + raise TypeError( + "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM." + ) + + +def apply_liger_kernel_to_qwen2( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. 
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + """ + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) + + from transformers.models.qwen2 import modeling_qwen2 + from transformers.models.qwen2.modeling_qwen2 import Qwen2Model + + if rope: + modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb + if rms_norm: + modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm + + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(qwen2_lce_forward, model) + else: + modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward + + if swiglu: + modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: Qwen2Model = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm) + + for decoder_layer in base_model.layers: + if swiglu: + _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP) + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + + +def apply_liger_kernel_to_qwen3( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models. + """ + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." 
+    )
+
+    from transformers.models.qwen3 import modeling_qwen3
+    from transformers.models.qwen3.modeling_qwen3 import Qwen3Model
+
+    from liger_kernel.transformers.model.qwen3 import lce_forward as qwen3_lce_forward
+
+    if rope:
+        modeling_qwen3.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+
+    if swiglu:
+        modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Qwen3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_qwen3_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 MoE models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_moe import modeling_qwen3_moe
+    from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel
+
+    from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_lce_forward
+    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+    if rope:
+        modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+
+    if swiglu:
+        if IS_TRANSFORMERS_V5_OR_LATER:
+            modeling_qwen3_moe.Qwen3MoeExperts = LigerExperts
+        else:
+            modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                if IS_TRANSFORMERS_V5_OR_LATER:
+                    _patch_swiglu_module(decoder_layer.mlp.experts, LigerExperts)
+                else:
+                    for mlp_expert in decoder_layer.mlp.experts:
+                        _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_gpt_oss(
+    rope: bool = True,
+    cross_entropy: bool =
False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = False, # Set to False by default since GPT-OSS has custom expert implementation + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models. + NOTE: GPT-OSS is supported in transformers >= 4.55.0 + NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert + implementation with clamping and MXFP4 quantization. + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False. + Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + """ + if version.parse(transformers.__version__) < version.parse("4.55.0"): + logger.warning("GPT-OSS support requires transformers >= 4.55.0") + return + + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) + + from transformers.models.gpt_oss import modeling_gpt_oss + from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel + + if rope: + modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb + + if rms_norm: + modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm + + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(gpt_oss_lce_forward, model) + else: + modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward + + # Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation + # with clamping (swiglu_limit=7.0) and MXFP4 quantization + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: GptOssModel = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm) + for decoder_layer in base_model.layers: + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + + +def apply_liger_kernel_to_qwen2_vl( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + layer_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models. + NOTE: Qwen2-VL is not supported in transformers<4.52.4 + + Args: + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. 
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + """ + if transformer_version < version.parse("4.52.4"): + logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4") + return + + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) + + from transformers.models.qwen2_vl import modeling_qwen2_vl + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel + + from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward + + if rope: + modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb + if rms_norm: + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439 + modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm + if layer_norm and model is None: + modeling_qwen2_vl.LayerNorm = LigerLayerNorm + if cross_entropy: + modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(qwen2_vl_lce_forward, model) + else: + modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward + if swiglu: + modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + if isinstance(model, Qwen2VLForConditionalGeneration): + text_model: Qwen2VLTextModel = model.model.language_model + vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual + elif isinstance(model, Qwen2VLModel): + text_model: Qwen2VLTextModel = model.language_model + vision_model: Qwen2VisionTransformerPretrainedModel = model.visual + elif isinstance(model, Qwen2VLTextModel): + text_model: Qwen2VLTextModel = model + vision_model = None + else: + # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed. + raise TypeError( + f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. 
Got: {type(model)}"
+            )
+
+        # Patch Qwen2VisionTransformerPretrainedModel
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
+                if layer_norm:
+                    _patch_layer_norm_module(vision_block.norm1)
+                    _patch_layer_norm_module(vision_block.norm2)
+
+        # Patch Qwen2VLTextModel
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_qwen2_5_vl(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen2.5-VL models.
+    NOTE: Qwen2.5-VL is not available in transformers<4.48.2, and this patch additionally requires transformers>=4.52.4.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
+        return
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
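+        # A usage sketch for reference (hedged: the checkpoint name is illustrative,
+        # and this assumes the function is exported from `liger_kernel.transformers`
+        # like the other apply_* entry points in this repo):
+        #
+        #   from transformers import Qwen2_5_VLForConditionalGeneration
+        #   from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl
+        #
+        #   apply_liger_kernel_to_qwen2_5_vl()  # pre-load: patch the classes first
+        #   model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+        #   # or, for an already loaded instance:
+        #   # apply_liger_kernel_to_qwen2_5_vl(model=model)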
+    )
+
+    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
+
+    from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
+
+    if rope:
+        modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
+    if rms_norm:
+        modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
+    if cross_entropy:
+        modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen2_5_vl_lce_forward, model)
+        else:
+            modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
+    if swiglu:
+        modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Qwen2_5_VLForConditionalGeneration):
+            text_model: Qwen2_5_VLTextModel = model.model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual
+        elif isinstance(model, Qwen2_5_VLModel):
+            text_model: Qwen2_5_VLTextModel = model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2_5_VLTextModel):
+            text_model: Qwen2_5_VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2.5-VL model type. `model` must be `Qwen2_5_VLForConditionalGeneration`, `Qwen2_5_VLModel` or `Qwen2_5_VLTextModel`. Got: {type(model)}"
+            )
+
+        if vision_model is not None:
+            # Patch Qwen2_5_VisionTransformerPretrainedModel
+            for vision_block in vision_model.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_qwen3_vl(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded.
+        Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_vl import modeling_qwen3_vl
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
+
+    from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
+
+    if rope:
+        modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
+        modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
+
+    if rms_norm:
+        modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_lce_forward, model)
+        else:
+            modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
+
+    if model is not None and rms_norm:
+        if isinstance(model, Qwen3VLForConditionalGeneration):
+            text_model: Qwen3VLTextModel = model.model.language_model
+        elif isinstance(model, Qwen3VLModel):
+            text_model: Qwen3VLTextModel = model.language_model
+        elif isinstance(model, Qwen3VLTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
+            )
+
+        _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+        if text_model is not None:
+            _patch_qwen3_vl_rms_norm(text_model.norm)
+            for decoder_layer in text_model.layers:
+                _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
+                _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
+                self_attn = getattr(decoder_layer, "self_attn", None)
+                if self_attn is not None:
+                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                        _patch_qwen3_vl_rms_norm(self_attn.q_norm)
+                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                        _patch_qwen3_vl_rms_norm(self_attn.k_norm)
+
+
+def apply_liger_kernel_to_qwen3_vl_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
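+        # Note on the RMSNorm patching in this function (sketch of the partial
+        # defined further below, shown here for orientation):
+        #   patch = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+        # i.e. llama-style casting with no (1 + weight) offset, applied to the final
+        # norm, both per-layer norms, and the per-head q_norm/k_norm when present.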
+ ) + + from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel + + from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward + + if rope: + modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb + modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision + + if rms_norm: + modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm + + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(qwen3_vl_moe_lce_forward, model) + else: + modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward + + if model is not None and rms_norm: + if isinstance(model, Qwen3VLMoeForConditionalGeneration): + text_model: Qwen3VLMoeTextModel = model.model.language_model + elif isinstance(model, Qwen3VLMoeModel): + text_model: Qwen3VLMoeTextModel = model.language_model + elif isinstance(model, Qwen3VLMoeTextModel): + text_model = model + else: + raise TypeError( + f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}" + ) + + _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama") + + if text_model is not None: + _patch_qwen3_vl_moe_rms_norm(text_model.norm) + for decoder_layer in text_model.layers: + _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm) + _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm) + self_attn = getattr(decoder_layer, "self_attn", None) + if self_attn is not None: + if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None: + _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm) + if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None: + _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm) + + +def apply_liger_kernel_to_phi3( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Phi3 models. + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + """ + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." 
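+        # What the two patching modes below amount to (a sketch, not executed here;
+        # names are the ones used in this file):
+        #   pre-load : modeling_phi3.Phi3RMSNorm = LigerRMSNorm
+        #              -> modules instantiated afterwards are Liger modules
+        #   post-load: _patch_rms_norm_module(existing_module)
+        #              -> rebinds forward on already-instantiated modules
+        # This is why the `model is not None` case gets the extra instance-walking branch.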
+ ) + + from transformers.models.phi3 import modeling_phi3 + from transformers.models.phi3.modeling_phi3 import Phi3Model + + if rope: + modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma + if rms_norm: + modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama + if swiglu: + modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(phi3_lce_forward, model) + else: + modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: Phi3Model = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm) + + for decoder_layer in base_model.layers: + if swiglu: + _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP) + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + + +def apply_liger_kernel_to_olmo2( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models. + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + """ + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." 
+ ) + + from transformers.models.olmo2 import modeling_olmo2 + from transformers.models.olmo2.modeling_olmo2 import Olmo2Model + + from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward + from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2 + + if rope: + modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb + if rms_norm: + modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2 + if swiglu: + modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(olmo2_lce_forward, model) + else: + modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: Olmo2Model = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm) + + for decoder_layer in base_model.layers: + if swiglu: + _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP) + if rms_norm: + _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False) + _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False) + + +def apply_liger_kernel_to_olmo3( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models. + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + """ + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) + + from transformers.models.olmo3 import modeling_olmo3 + from transformers.models.olmo3.modeling_olmo3 import Olmo3Model + + from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward + from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2 + + # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way. 
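+    # Concretely, the class-level swaps below mirror apply_liger_kernel_to_olmo2:
+    #   modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2
+    #   modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
+    # and the instance path patches only the post_attention/post_feedforward norms
+    # (the Olmo2/Olmo3 blocks apply their norms after attention/MLP), using
+    # in_place=False exactly as the Olmo2 path does.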
+ if rope: + modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb + if rms_norm: + modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2 # same as olmo2 + if swiglu: + modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(olmo3_lce_forward, model) + else: + modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: Olmo3Model = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm) + + for decoder_layer in base_model.layers: + if swiglu: + _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP) + if rms_norm: + _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False) + _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False) + + +def apply_liger_kernel_to_glm4( + rope: bool = False, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models. + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is False. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + """ + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." 
+ ) + + from transformers.models.glm4 import modeling_glm4 + from transformers.models.glm4.modeling_glm4 import Glm4Model + + from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward + from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4 + + if rope: + raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.") + if rms_norm: + modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4 + if swiglu: + modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(glm4_lce_forward, model) + else: + modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: Glm4Model = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm, in_place=False) + + for decoder_layer in base_model.layers: + if swiglu: + _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP) + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False) + _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False) + _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False) + + +def apply_liger_kernel_to_glm4v( + rope: bool = False, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models. + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is False. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + """ + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." 
+    )
+
+    from transformers.models.glm4v import modeling_glm4v
+    from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
+    from transformers.models.glm4v.modeling_glm4v import Glm4vModel
+    from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
+    from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel
+
+    from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4v models.")
+    if rms_norm:
+        modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(glm4v_lce_forward, model)
+        else:
+            modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Glm4vForConditionalGeneration):
+            text_model: Glm4vTextModel = model.model.language_model
+            vision_model: Glm4vVisionModel = model.model.visual
+        elif isinstance(model, Glm4vModel):
+            text_model: Glm4vTextModel = model.language_model
+            vision_model: Glm4vVisionModel = model.visual
+        elif isinstance(model, Glm4vTextModel):
+            text_model: Glm4vTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Glm4v model type. `model` must be `Glm4vForConditionalGeneration`, `Glm4vModel` or `Glm4vTextModel`. Got: {type(model)}"
+            )
+
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+                if swiglu:
+                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
+
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
+
+
+def apply_liger_kernel_to_glm4v_moe(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.glm4v_moe import modeling_glm4v_moe
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
+
+    from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4v_moe models.")
+    if rms_norm:
+        modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
+        modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(glm4v_moe_lce_forward, model)
+        else:
+            modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules.
+        # Bind the MoE block class up front so it is defined for every branch below.
+        Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+        if isinstance(model, Glm4vMoeForConditionalGeneration):
+            text_model: Glm4vMoeTextModel = model.model.language_model
+            vision_model: Glm4vMoeVisionModel = model.model.visual
+        elif isinstance(model, Glm4vMoeModel):
+            text_model: Glm4vMoeTextModel = model.language_model
+            vision_model: Glm4vMoeVisionModel = model.visual
+        elif isinstance(model, Glm4vMoeTextModel):
+            text_model: Glm4vMoeTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
+            )
+
+        if vision_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(vision_model.post_conv_layernorm)
+                _patch_rms_norm_module(vision_model.post_layernorm)
+            for vision_block in vision_model.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+                if swiglu:
+                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
+
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    if isinstance(decoder_layer.mlp, Glm4vMoeTextMoE):
+                        # MoE layers: patch the routed experts and the shared experts
+                        experts = getattr(decoder_layer.mlp, "experts", None)
+                        if experts is not None:
+                            for expert in experts:
+                                _patch_swiglu_module(expert, LigerSwiGLUMLP)
+                        if decoder_layer.mlp.shared_experts is not None:
+                            _patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
+                    else:
+                        # Dense layers: patch the MLP module directly
+                        _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_internvl(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    layer_norm: bool = True,
+    model: Optional[PreTrainedModel] = None,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
+    Because InternVL wraps a separate language model, the loaded model instance must be passed
+    so that Liger's patches can also be applied to the connected LM.
+    However, if an LM not supported by Liger-Kernel is connected to InternVL, unexpected side effects may occur.
+    NOTE: InternVL is not available in transformers<4.52.1
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
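+        # Dispatch sketch of what follows (all names come from this file): the
+        # connected LM is patched via its own apply_* function, with kwargs filtered
+        # down to the parameters that function actually accepts:
+        #   text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model.config.text_config.model_type)
+        #   accept_params = inspect.signature(text_liger_fn).parameters
+        #   text_kwargs = {k: v for k, v in kwargs.items() if k in accept_params}
+        #   text_liger_fn(model=text_model, **text_kwargs)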
+ ) + import torch.nn as torch_nn + + from transformers.models.internvl import modeling_internvl + from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration + from transformers.models.internvl.modeling_internvl import InternVLModel + from transformers.models.internvl.modeling_internvl import InternVLVisionLayer + from transformers.models.internvl.modeling_internvl import InternVLVisionModel + from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm + + from liger_kernel.transformers.layer_norm import LigerLayerNorm + from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward + from liger_kernel.transformers.rms_norm import LigerRMSNorm + + if layer_norm and model is None: + modeling_internvl.nn.LayerNorm = LigerLayerNorm + + if cross_entropy: + logger.info("Apply liger cross entropy") + + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + if fused_linear_cross_entropy: + modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward + if rms_norm: + modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + if isinstance(model, InternVLForConditionalGeneration): + text_model = model.model.language_model + vision_model: InternVLVisionModel = model.model.vision_tower + elif isinstance(model, InternVLModel): + text_model = model.language_model + vision_model: InternVLVisionModel = model.vision_tower + else: + raise TypeError( + f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}" + ) + + text_model_name = model.config.text_config.model_type + text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None) + + kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm} + if text_liger_fn: + accept_params = inspect.signature(text_liger_fn).parameters + remain_params = set(kwargs) - (set(accept_params) & set(kwargs)) + text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params} + + if remain_params: + logger.warning( + f"These parameters are not supported by {text_model_name}. 
Applying only {list(text_kwargs.keys())}; ignoring {list(remain_params)}\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = text_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        # Patch vision model RMSNorm layers
+        if rms_norm:
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.q_norm)
+                if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.k_norm)
+
+        # Patch vision model LayerNorm layers
+        if layer_norm:
+            # Patch layernorm
+            if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
+                _patch_layer_norm_module(vision_model.layernorm)
+
+            # Patch encoder layers
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_before)
+                if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_after)
+
+
+def apply_liger_kernel_to_smolvlm(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    layer_norm: bool = True,
+    model: Optional[PreTrainedModel] = None,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
+    Because SmolVLM wraps a separate language model, the loaded model instance must be passed
+    so that Liger's patches can also be applied to the connected LM.
+    However, if an LM not supported by Liger-Kernel is connected to SmolVLM, unexpected side effects may occur.
+    NOTE: SmolVLM is not available in transformers<4.50.0
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.smolvlm import modeling_smolvlm
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
+
+    from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
+
+    # Patch LayerNorm for vision model if model is not provided (pre-initialization)
+    if layer_norm and model is None:
+        modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
+
+    if cross_entropy:
+        logger.info("Apply liger cross entropy")
+
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(smolvlm_lce_forward, model)
+        else:
+            modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
+    if rms_norm:
+        modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, SmolVLMForConditionalGeneration):
+            text_model = model.model.text_model
+            vision_model: SmolVLMVisionTransformer = model.model.vision_model
+        elif isinstance(model, SmolVLMModel):
+            text_model = model.text_model
+            vision_model: SmolVLMVisionTransformer = model.vision_model
+        else:
+            raise TypeError(
+                f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration` or `SmolVLMModel`. Got: {type(model)}"
+            )
+
+        text_model_name = model.config.text_config.model_type
+        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+        if text_liger_fn:
+            accept_params = inspect.signature(text_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {text_model_name}. Applying only {list(text_kwargs.keys())}; ignoring {list(remain_params)}\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = text_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        # Patch vision model LayerNorm layers
+        if layer_norm:
+            # Patch post_layernorm
+            _patch_layer_norm_module(vision_model.post_layernorm)
+
+            # Patch encoder layers
+            for encoder_layer in vision_model.encoder.layers:
+                encoder_layer: SmolVLMEncoderLayer
+                _patch_layer_norm_module(encoder_layer.layer_norm1)
+                _patch_layer_norm_module(encoder_layer.layer_norm2)
+
+
+def apply_liger_kernel_to_falcon_h1(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.falcon_h1 import modeling_falcon_h1
+    from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
+
+    if rope:
+        logger.info("Apply liger rotary pos emb.")
+        modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        logger.info("Apply liger RMSNorm")
+        modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
+    if swiglu:
+        logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
+
+    if cross_entropy:
+        logger.info("Apply liger cross entropy")
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(falcon_h1_lce_forward, model)
+        else:
+            modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. FalconH1RMSNorm)
+
+        # get the base model from the model instance
+        base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.final_layernorm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
+
+
+def apply_liger_kernel_to_qwen3_next(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-Next models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_next import modeling_qwen3_next
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
+
+    from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
+    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+    if rope:
+        # It might encounter NaN issues:
+        # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
+    if rms_norm:
+        modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            if isinstance(model, Qwen3NextForCausalLM):
+                model.forward = MethodType(qwen3_next_lce_forward, model)
+            else:
+                raise TypeError(
+                    f"fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
+                )
+        else:
+            modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
+    if swiglu:
+        if IS_TRANSFORMERS_V5_OR_LATER:
+            modeling_qwen3_next.Qwen3NextExperts = LigerExperts
+        else:
+            # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+            modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
+            base_model: Qwen3NextModel = getattr(model, model.base_model_prefix, model)
+        else:
+            raise TypeError(
+                f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM` or `Qwen3NextModel`. Got: {type(model)}"
+            )
+
+        _patch_rms_norm_module_for_qwen3_next = partial(
+            _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+        )
+
+        if rms_norm:
+            _patch_rms_norm_module_for_qwen3_next(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if rms_norm:
+                _patch_rms_norm_module_for_qwen3_next(decoder_layer.input_layernorm)
+                _patch_rms_norm_module_for_qwen3_next(decoder_layer.post_attention_layernorm)
+
+            # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+            if swiglu:
+                if isinstance(decoder_layer.mlp, Qwen3NextMLP):
+                    _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+                if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
+                    _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
+                    experts = getattr(decoder_layer.mlp, "experts", None)
+                    if experts is not None:
+                        if IS_TRANSFORMERS_V5_OR_LATER:
+                            _patch_swiglu_module(experts, LigerExperts)
+                        else:
+                            for expert in experts:
+                                _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
+
+
+def apply_liger_kernel_to_qwen3_5(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 dense models.
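+
+    A minimal usage sketch (illustrative; assumes this function is exported from
+    `liger_kernel.transformers` like the other apply_* entry points):
+
+        apply_liger_kernel_to_qwen3_5()          # pre-load: patch the classes
+        apply_liger_kernel_to_qwen3_5(model=m)   # post-load: patch instance `m`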
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+            Not yet supported for Qwen3.5 due to hybrid attention (Gated DeltaNet + Gated Attention).
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_5 import modeling_qwen3_5
+    from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
+    from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5TextModel
+
+    try:
+        from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForConditionalGeneration
+    except ImportError:
+        Qwen3_5ForConditionalGeneration = None
+
+    from liger_kernel.transformers.model.qwen3_5 import lce_forward as qwen3_5_lce_forward
+    from liger_kernel.transformers.model.qwen3_5 import lce_forward_for_multimodal as qwen3_5_lce_forward_for_multimodal
+    from liger_kernel.transformers.monkey_patch import _patch_rms_norm_module
+    from liger_kernel.transformers.monkey_patch import _patch_swiglu_module
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
+    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3_5 models.")
+
+    if rms_norm:
+        modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3Next
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        from liger_kernel.transformers.cross_entropy import liger_cross_entropy
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            if isinstance(model, Qwen3_5ForCausalLM):
+                model.forward = MethodType(qwen3_5_lce_forward, model)
+            elif Qwen3_5ForConditionalGeneration is not None and isinstance(model, Qwen3_5ForConditionalGeneration):
+                model.forward = MethodType(qwen3_5_lce_forward_for_multimodal, model)
+            else:
+                raise TypeError(
+                    f"fused_linear_cross_entropy is only applicable on Qwen3_5ForCausalLM or Qwen3_5ForConditionalGeneration. Got: {type(model)}"
+                )
+        else:
+            modeling_qwen3_5.Qwen3_5ForCausalLM.forward = qwen3_5_lce_forward
+            if Qwen3_5ForConditionalGeneration is not None:
+                modeling_qwen3_5.Qwen3_5ForConditionalGeneration.forward = qwen3_5_lce_forward_for_multimodal
+
+    if swiglu:
+        modeling_qwen3_5.Qwen3_5MLP = LigerQwen3MoeSwiGLUMLP
+
+    if model is not None:
+        if isinstance(model, (Qwen3_5ForCausalLM, Qwen3_5TextModel)):
+            text_model: Qwen3_5TextModel = getattr(model, model.base_model_prefix, model)
+        elif Qwen3_5ForConditionalGeneration is not None and isinstance(model, Qwen3_5ForConditionalGeneration):
+            text_model = model.model.language_model
+        else:
+            raise TypeError(f"Unsupported qwen3_5 model type. Got: {type(model)}")
+
+        _patch_rms_norm_module_for_qwen3_5 = partial(
+            _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+        )
+
+        if rms_norm:
+            _patch_rms_norm_module_for_qwen3_5(text_model.norm)
+
+        for decoder_layer in text_model.layers:
+            if rms_norm:
+                _patch_rms_norm_module_for_qwen3_5(decoder_layer.input_layernorm)
+                _patch_rms_norm_module_for_qwen3_5(decoder_layer.post_attention_layernorm)
+
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+
+
+def apply_liger_kernel_to_qwen3_5_moe(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 MoE models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_5_moe import modeling_qwen3_5_moe
+    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM
+    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeTextModel
+
+    from liger_kernel.transformers.model.qwen3_5_moe import lce_forward as qwen3_5_moe_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
+    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3_5Moe models.")
+    if rms_norm:
+        modeling_qwen3_5_moe.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3Next
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            if isinstance(model, Qwen3_5MoeForCausalLM):
+                model.forward = MethodType(qwen3_5_moe_lce_forward, model)
+            else:
+                raise TypeError(
+                    f"fused_linear_cross_entropy is only applicable on Qwen3_5MoeForCausalLM. Got: {type(model)}"
+                )
+        else:
+            modeling_qwen3_5_moe.Qwen3_5MoeForCausalLM.forward = qwen3_5_moe_lce_forward
+    if swiglu:
+        modeling_qwen3_5_moe.Qwen3_5MoeExperts = LigerExperts
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (Qwen3_5MoeForCausalLM, Qwen3_5MoeTextModel)):
+            base_model: Qwen3_5MoeTextModel = getattr(model, model.base_model_prefix, model)
+        else:
+            raise TypeError(
+                f"Unsupported qwen3_5_moe model type. `model` must be `Qwen3_5MoeForCausalLM` or `Qwen3_5MoeTextModel`. Got: {type(model)}"
+            )
+
+        _patch_rms_norm_module_for_qwen3_5_moe = partial(
+            _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+        )
+
+        if rms_norm:
+            _patch_rms_norm_module_for_qwen3_5_moe(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if rms_norm:
+                _patch_rms_norm_module_for_qwen3_5_moe(decoder_layer.input_layernorm)
+                _patch_rms_norm_module_for_qwen3_5_moe(decoder_layer.post_attention_layernorm)
+
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
+                experts = getattr(decoder_layer.mlp, "experts", None)
+                if experts is not None:
+                    _patch_swiglu_module(experts, LigerExperts)
+
+
+def apply_liger_kernel_to_hunyuan_v1_dense(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
+    from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
+
+    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
+    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+    if rope:
+        modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(hunyuan_v1_lce_forward, model)
+        else:
+            modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
+
+    if swiglu:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_hunyuan_v1_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 MoE models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
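+        # Version note, sketching the swiglu branch below (an observation from this
+        # file's code paths, not additional behavior): on transformers v5+ the whole
+        # `experts` container is patched as one module, while on earlier versions
+        # each expert is a separate MLP and is patched individually:
+        #   if IS_TRANSFORMERS_V5_OR_LATER:
+        #       _patch_swiglu_module(decoder_layer.mlp.experts, LigerExperts)
+        #   else:
+        #       for mlp_expert in decoder_layer.mlp.experts:
+        #           _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)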
+ ) + + from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe + from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model + + from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward + from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP + + if rope: + modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb + + if rms_norm: + modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm + + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(hunyuan_v1_moe_lce_forward, model) + else: + modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward + + if swiglu: + if IS_TRANSFORMERS_V5_OR_LATER: + modeling_hunyuan_v1_moe.HunYuanMoEV1Experts = LigerExperts + else: + modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm) + for decoder_layer in base_model.layers: + if swiglu: + if IS_TRANSFORMERS_V5_OR_LATER: + _patch_swiglu_module(decoder_layer.mlp.experts, LigerExperts) + else: + for mlp_expert in decoder_layer.mlp.experts: + _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP) + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + + +def apply_liger_kernel_to_exaone4( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace EXAONE4 models. + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + """ + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ ) + + from transformers.models.exaone4 import modeling_exaone4 + from transformers.models.exaone4.modeling_exaone4 import Exaone4Model + + from liger_kernel.transformers.model.exaone4 import lce_forward as exaone4_lce_forward + + if rope: + modeling_exaone4.apply_rotary_pos_emb = liger_rotary_pos_emb + + if rms_norm: + # EXAONE4 requires in_place=False to avoid gradient issues + class Exaone4LigerRMSNorm(LigerRMSNorm): + def __init__(self, hidden_size, eps=1e-6, **kwargs): + super().__init__(hidden_size, eps, **kwargs) + self.in_place = False + + modeling_exaone4.Exaone4RMSNorm = Exaone4LigerRMSNorm + + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(exaone4_lce_forward, model) + else: + modeling_exaone4.Exaone4ForCausalLM.forward = exaone4_lce_forward + + if swiglu: + modeling_exaone4.Exaone4MLP = LigerSwiGLUMLP + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: Exaone4Model = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm, in_place=False) + for decoder_layer in base_model.layers: + if swiglu: + _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward) + if rms_norm: + _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False) + _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False) + _patch_rms_norm_module(decoder_layer.self_attn.q_norm, in_place=False) + _patch_rms_norm_module(decoder_layer.self_attn.k_norm, in_place=False) + + +# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py +MODEL_TYPE_TO_APPLY_LIGER_FN = { + "gemma": apply_liger_kernel_to_gemma, + "gemma2": apply_liger_kernel_to_gemma2, + "gemma3_text": apply_liger_kernel_to_gemma3_text, + "gemma3": apply_liger_kernel_to_gemma3, + "glm4": apply_liger_kernel_to_glm4, + "glm4v": apply_liger_kernel_to_glm4v, + "glm4v_moe": apply_liger_kernel_to_glm4v_moe, + "gpt_oss": apply_liger_kernel_to_gpt_oss, + "internvl": apply_liger_kernel_to_internvl, + "llama": apply_liger_kernel_to_llama, + "llama4_text": apply_liger_kernel_to_llama4, + "llama4": apply_liger_kernel_to_llama4, + "llava": apply_liger_kernel_to_llava, + "granite": apply_liger_kernel_to_granite, + "mllama": apply_liger_kernel_to_mllama, + "mllama_text_model": apply_liger_kernel_to_mllama, + "mistral": apply_liger_kernel_to_mistral, + "mixtral": apply_liger_kernel_to_mixtral, + "olmo2": apply_liger_kernel_to_olmo2, + "pixtral": apply_liger_kernel_to_pixtral, + "olmo3": apply_liger_kernel_to_olmo3, + "qwen2": apply_liger_kernel_to_qwen2, + "qwen3": apply_liger_kernel_to_qwen3, + "qwen3_moe": apply_liger_kernel_to_qwen3_moe, + "qwen2_vl": apply_liger_kernel_to_qwen2_vl, + "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl, + "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl, + "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl, + "qwen3_next": apply_liger_kernel_to_qwen3_next, + "qwen3_5": apply_liger_kernel_to_qwen3_5, + "qwen3_5_text": apply_liger_kernel_to_qwen3_5, + "qwen3_5_moe": apply_liger_kernel_to_qwen3_5_moe, + "qwen3_5_moe_text": apply_liger_kernel_to_qwen3_5_moe, + "qwen3_vl": apply_liger_kernel_to_qwen3_vl, + "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl, + "qwen3_vl_moe": 
apply_liger_kernel_to_qwen3_vl_moe, + "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe, + "smollm3": apply_liger_kernel_to_smollm3, + "phi3": apply_liger_kernel_to_phi3, + "paligemma": apply_liger_kernel_to_paligemma, + "falcon_h1": apply_liger_kernel_to_falcon_h1, + "smolvlm": apply_liger_kernel_to_smolvlm, + "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense, + "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe, + "exaone4": apply_liger_kernel_to_exaone4, +} + + +def _apply_liger_kernel(model_type: str, **kwargs) -> None: + """ + Applies Liger kernels based on the specified model type. The custom + kernels for the specified model type will be applied with the provided + keyword arguments, otherwise the default configuration will be used. + + ** Note: Calling _apply_liger_kernel() after model initialization cannot fully patch the model; + it must be called before model initialization. + If the model has already been instantiated, use _apply_liger_kernel_to_instance() instead. + + Args: + - model_type: the model type as defined in transformers/models/auto/modeling_auto.py + and specified in the model's config.json + - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function. + """ + if not model_type: + logger.info("Model type was not provided. No Liger kernels will be applied.") + return + + if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys(): + logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.") + return + + apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type] + apply_fn_signature = inspect.signature(apply_fn) + + # Filter out the keyword arguments that are not supported by the apply function + applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters} + + logger.info(f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}") + + # Assume this is invoked pre-model initialization, so we only need to patch transformers code + apply_fn(**applicable_kwargs) + + +def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None: + """ + Applies Liger kernels to the provided model instance. + + Args: + - model: the model instance to apply Liger kernels to + - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function. + """ + model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None) + + if not model_type: + logger.info("Model type could not be determined from model config.
No Liger kernels will be applied.") + return + + if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys(): + logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.") + return + + apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type] + apply_fn_signature = inspect.signature(apply_fn) + + # Filter out the keyword arguments that are not supported by the apply function + applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters} + logger.info( + f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}" + ) + + apply_fn(model=model, **applicable_kwargs) diff --git a/src/liger_kernel/transformers/multi_token_attention.py b/src/liger_kernel/transformers/multi_token_attention.py new file mode 100755 index 0000000000000000000000000000000000000000..38b5c6891d6fc7f8d003950640bac73d546945c5 --- /dev/null +++ b/src/liger_kernel/transformers/multi_token_attention.py @@ -0,0 +1,64 @@ +import math + +import torch +import torch.nn as nn + +from torch.nn.modules.utils import _pair + +from liger_kernel.ops import LigerMultiTokenAttentionFunction + + +class LigerMultiTokenAttention(nn.Module): + r""" + Multi-Token Attention: + out = mask_{0}(conv2d(softmax(mask_{-\inf}(scores)))) + + Reference: https://arxiv.org/pdf/2504.00927 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + sparse: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.sparse = sparse + + self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.empty(out_channels)) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, scores: torch.Tensor) -> torch.Tensor: + return LigerMultiTokenAttentionFunction.apply( + scores, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + self.sparse, + ) diff --git a/src/liger_kernel/transformers/poly_norm.py b/src/liger_kernel/transformers/poly_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..24b991db37422b7f84573f9466126f37909a5524 --- /dev/null +++ b/src/liger_kernel/transformers/poly_norm.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn + +from liger_kernel.ops import LigerPolyNormFunction + + +class LigerPolyNorm(nn.Module): + """ + PolyNorm layer wrapper for Liger kernel. + + PolyNorm formula: + y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b + where norm(u) = u / sqrt(mean(u²) + ε) + + Reference: + https://github.com/BryceZhuo/PolyCom/ + + Args: + eps: epsilon for numerical stability (default: 1e-6) + in_place: whether to in-place modify grad_output in backward to save memory (default: False). + Set to True to save memory if grad_output is not needed elsewhere. 
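+ + Example (illustrative sketch only; assumes a CUDA device and the defaults above): + >>> norm = LigerPolyNorm(eps=1e-6).to("cuda") + >>> x = torch.randn(2, 128, 1024, device="cuda") + >>> y = norm(x) # output has the same shape as x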
+ """ + + def __init__(self, eps=1e-6, in_place=True): + super().__init__() + # Align with PolyCom reference: initialize weights to (1/3, 1/3, 1/3) and bias to 1.0 + self.weight = nn.Parameter(torch.full((3,), 1.0 / 3.0)) + self.bias = nn.Parameter(torch.tensor(1.0)) + self.variance_epsilon = eps + self.in_place = in_place + + def forward(self, hidden_states): + return LigerPolyNormFunction.apply( + hidden_states, + self.weight, + self.bias, + self.variance_epsilon, + self.in_place, + ) + + def extra_repr(self): + return f"weight_shape={tuple(self.weight.shape)}, eps={self.variance_epsilon}, in_place={self.in_place}" diff --git a/src/liger_kernel/transformers/qwen2vl_mrope.py b/src/liger_kernel/transformers/qwen2vl_mrope.py new file mode 100755 index 0000000000000000000000000000000000000000..75c2b623b65d5e0ddfcdfa1e6f05f1874fcfc92c --- /dev/null +++ b/src/liger_kernel/transformers/qwen2vl_mrope.py @@ -0,0 +1,20 @@ +from liger_kernel.ops import LigerQwen2VLMRopeFunction + + +def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """ + Applies Multimodal Rotary Positional Embedding (M-RoPE) operation to query and key states. + + Args: + q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim). + k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim). + cos (torch.Tensor): The cosine tensor of shape (3, bsz, seq_len, head_dim). + sin (torch.Tensor): The sine tensor of shape (3, bsz, seq_len, head_dim). + mrope_section (List[int]): The multimodal rope section for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the M-RoPE operation. 
+ """ + + return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim) diff --git a/src/liger_kernel/transformers/rms_norm.py b/src/liger_kernel/transformers/rms_norm.py new file mode 100755 index 0000000000000000000000000000000000000000..3f5aa7684d7d9324403944565c2459ae8e70b854 --- /dev/null +++ b/src/liger_kernel/transformers/rms_norm.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn + +from liger_kernel.ops import LigerRMSNormFunction + + +class LigerRMSNorm(nn.Module): + def __init__( + self, + hidden_size, + eps=1e-6, + offset=0.0, + casting_mode="llama", + init_fn="ones", + in_place=True, + row_mode=None, + elementwise_affine=True, + ): + super().__init__() + assert init_fn in [ + "ones", + "zeros", + ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}" + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)) + else: + self.register_parameter("weight", None) + self.variance_epsilon, self.offset, self.casting_mode, self.in_place, self.row_mode = ( + eps, + offset, + casting_mode, + in_place, + row_mode, + ) + + def forward(self, hidden_states): + return LigerRMSNormFunction.apply( + hidden_states, + self.weight, + self.variance_epsilon, + self.offset, + self.casting_mode, + self.in_place, + self.row_mode, + ) + + def extra_repr(self): + return f"weight_shape={tuple(self.weight.shape) if self.weight is not None else None}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}, row_mode={self.row_mode}" + + +class LigerRMSNormForGemma(LigerRMSNorm): + def __init__( + self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=True, row_mode=None + ): + super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode) + + +class LigerRMSNormForGemma2(LigerRMSNorm): + def __init__( + self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None + ): + super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode) + + +class LigerRMSNormForGemma3(LigerRMSNorm): + """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm.""" + + def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False): + super().__init__(dim, eps, offset, casting_mode, init_fn, in_place) + + +class LigerRMSNormForOlmo2(LigerRMSNorm): + def __init__( + self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None + ): + super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode) + + +class LigerRMSNormForGlm4(LigerRMSNorm): + def __init__( + self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None + ): + super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode) + + +class LigerRMSNormForQwen3Next(LigerRMSNorm): + def __init__( + self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None + ): + super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode) diff --git a/src/liger_kernel/transformers/rope.py b/src/liger_kernel/transformers/rope.py new file mode 100755 index 0000000000000000000000000000000000000000..ea2ca86ede0db4a0bd9b8349e78f0833c95e7a87 --- /dev/null +++ b/src/liger_kernel/transformers/rope.py @@ -0,0 +1,64 @@ +from 
typing import Tuple + +import torch + +from liger_kernel.ops import LigerRopeFunction + + +def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + Applies Rotary Positional Embedding (RoPE) operation to query and key states. + + Args: + q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim). + k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim). + cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim). + sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim). + position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None. + unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the RoPE operation. + """ + + return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim) + + +def liger_rotary_pos_emb_vision( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Modified version of liger_rotary_pos_emb for qwen3_vl's apply_rotary_pos_emb_vision function. + Manually transposes the input and output to match the expected shape for liger_rotary_pos_emb. + Reference: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L116 + + Args: + q (torch.Tensor): The query tensor of shape (seq_length, num_heads, head_dim), + with stride (num_heads * head_dim, head_dim, 1). + k (torch.Tensor): The key tensor of shape (seq_length, num_heads, head_dim), + with stride (num_heads * head_dim, head_dim, 1). Same as q. + cos (torch.Tensor): The cosine tensor of shape (seq_length, head_dim). + sin (torch.Tensor): The sine tensor of shape (seq_length, head_dim). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The query and key tensors with the same shape and stride as inputs.
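+ + Example (illustrative sketch; sizes are arbitrary assumptions): + >>> q = torch.randn(256, 16, 128, device="cuda") + >>> k = torch.randn(256, 16, 128, device="cuda") + >>> cos = torch.randn(256, 128, device="cuda") + >>> sin = torch.randn(256, 128, device="cuda") + >>> q_rot, k_rot = liger_rotary_pos_emb_vision(q, k, cos, sin)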
+ """ + orig_q_dtype, orig_k_dtype = q.dtype, k.dtype + + # tranpose to (1, num_heads, seq_length, head_dim) and cast to float32 to match liger_rotary_pos_emb input shape + # also unsqueeze for batch dim + q32 = q.to(torch.float32).unsqueeze(0).transpose(1, 2) + k32 = k.to(torch.float32).unsqueeze(0).transpose(1, 2) + cos32 = cos.to(torch.float32) + sin32 = sin.to(torch.float32) + + q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32) + + # transpose back to (seq_length, num_heads, head_dim) and cast back to original dtype + # also squeeze out batch dim + q_out = q_out.transpose(1, 2).squeeze(0).to(orig_q_dtype) + k_out = k_out.transpose(1, 2).squeeze(0).to(orig_k_dtype) + return q_out, k_out diff --git a/src/liger_kernel/transformers/softmax.py b/src/liger_kernel/transformers/softmax.py new file mode 100755 index 0000000000000000000000000000000000000000..1d81aa16304f1e01873adca6e39d8951d9b80ee9 --- /dev/null +++ b/src/liger_kernel/transformers/softmax.py @@ -0,0 +1,12 @@ +import torch +import torch.nn as nn + +from liger_kernel.ops import LigerSoftmaxFunction + + +class LigerSoftmax(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + return LigerSoftmaxFunction.apply(x) diff --git a/src/liger_kernel/transformers/sparsemax.py b/src/liger_kernel/transformers/sparsemax.py new file mode 100755 index 0000000000000000000000000000000000000000..af54aac9d889cde8ea7707c3615e9a239e6364c7 --- /dev/null +++ b/src/liger_kernel/transformers/sparsemax.py @@ -0,0 +1,16 @@ +import torch +import torch.nn as nn + +from liger_kernel.ops import LigerSparsemaxFunction + + +class LigerSparsemax(nn.Module): + def __init__(self, dim: int = -1): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return LigerSparsemaxFunction.apply(x, self.dim) + + def extra_repr(self) -> str: + return f"dim={self.dim}" diff --git a/src/liger_kernel/transformers/swiglu.py b/src/liger_kernel/transformers/swiglu.py new file mode 100755 index 0000000000000000000000000000000000000000..02bf7dadb9f306359f4dab8c874cb3f28e47d0f8 --- /dev/null +++ b/src/liger_kernel/transformers/swiglu.py @@ -0,0 +1,145 @@ +import torch +import torch.nn as nn + +from liger_kernel.ops import LigerSiLUMulFunction + + +class LigerSwiGLUMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + if config.hidden_act not in ["silu", "swish"]: + raise ValueError(f"Activation function {config.hidden_act} not supported.") + + def forward(self, x): + return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))) + + +class LigerBlockSparseTop2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + if config.hidden_act not in ["silu", "swish"]: + raise ValueError(f"Activation function {config.hidden_act} not supported.") + + def forward(self, x): + return 
self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x))) + + +class LigerExperts(nn.Module): + """ + Patch MixtralExperts for transformers v5 or later to use LigerSiLUMulFunction + https://github.com/huggingface/transformers/blob/393b4b3d28e29b4b05b19b4b7f3242a7fc893637/src/transformers/models/mixtral/modeling_mixtral.py#L63 + """ + + def __init__(self, config): + super().__init__() + if "num_experts" in config: + # qwen3_moe, qwen3_next uses num_experts + self.num_experts = config.num_experts + else: + self.num_experts = config.num_local_experts + if "moe_intermediate_size" in config: + # qwen3_moe, qwen3_next uses moe_intermediate_size + self.intermediate_dim = config.moe_intermediate_size + else: + self.intermediate_dim = config.intermediate_size + + self.hidden_dim = config.hidden_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + + if config.hidden_act not in ["silu", "swish"]: + raise ValueError(f"Activation function {config.hidden_act} not supported.") + + def forward(self, hidden_states, top_k_index, top_k_weights): + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = LigerSiLUMulFunction.apply(gate, up) + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + + +class LigerPhi3SwiGLUMLP(nn.Module): + """ + Patch Phi3MLP to use LigerSiLUMulFunction + https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/phi3/modeling_phi3.py#L241 + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + if config.hidden_act not in ["silu", "swish"]: + raise ValueError(f"Activation function {config.hidden_act} not supported.") + + def forward(self, x): + up_states = self.gate_up_proj(x) + gate, up_states = up_states.chunk(2, dim=-1) + return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states)) + + +class LigerQwen3MoeSwiGLUMLP(nn.Module): + """ + Patch Qwen3MoeMLP to use LigerSiLUMulFunction. 
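+ Illustrative note (an inference from the constructor below, not upstream documentation): the optional + intermediate_size argument lets this one class cover both the dense MLP (config.intermediate_size) and + per-expert MLPs, e.g. LigerQwen3MoeSwiGLUMLP(config, intermediate_size=config.moe_intermediate_size).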
+ https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/qwen3_moe/modular_qwen3_moe.py#L57 + """ + + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + if config.hidden_act not in ["silu", "swish"]: + raise ValueError(f"Activation function {config.hidden_act} not supported.") + + def forward(self, x): + return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))) + + +class LigerHunyuanV1SwiGLUMLP(nn.Module): + def __init__(self, config, layer_idx=None, is_shared_mlp=False): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.layer_idx = layer_idx + if config.hidden_act not in ["silu", "swish"]: + raise ValueError(f"Activation function {config.hidden_act} not supported.") + + def forward(self, x): + return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))) diff --git a/src/liger_kernel/transformers/tiled_mlp.py b/src/liger_kernel/transformers/tiled_mlp.py new file mode 100755 index 0000000000000000000000000000000000000000..b72507b2eeb19ccfe931c309fba52ec2be3f77ab --- /dev/null +++ b/src/liger_kernel/transformers/tiled_mlp.py @@ -0,0 +1,125 @@ +from typing import Optional + +import torch.nn as nn + +from liger_kernel.ops import LigerGELUMulFunction +from liger_kernel.ops import LigerSiLUMulFunction +from liger_kernel.ops import apply_tiled_mlp + + +class LigerTiledGEGLUMLP(nn.Module): + """ + Memory-efficient GEGLU MLP using tiled computation. + + This module combines GEGLU activation with tiled processing to handle + very long sequences efficiently. The forward pass is recomputed during + backward to save memory. + + Args: + config: Model configuration with hidden_size and intermediate_size attributes + num_shards: Number of shards to split the sequence. 
If None, automatically + calculated as ceil(seqlen / hidden_size) + """ + + def __init__(self, config, num_shards: Optional[int] = None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.num_shards = num_shards + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + # Validate activation function + if hasattr(config, "hidden_act") and config.hidden_act not in [ + "gelu", + "gelu_new", + "gelu_pytorch_tanh", + ]: + raise ValueError(f"LigerTiledGEGLUMLP requires GELU activation, got {config.hidden_act}") + + def _mlp_forward(self, module, x): + """Internal MLP forward function for tiled computation.""" + gate = module.gate_proj(x) + up = module.up_proj(x) + return module.down_proj(LigerGELUMulFunction.apply(gate, up)) + + def forward(self, x): + """ + Forward pass with tiled computation. + + Args: + x: Input tensor of shape [batch_size, seq_len, hidden_size] + or [seq_len, hidden_size] + + Returns: + Output tensor of the same shape as input + """ + compute_params = [p for p in self.parameters() if p.requires_grad] + + return apply_tiled_mlp( + fn=self._mlp_forward, + mlp_module=self, + x=x, + num_shards=self.num_shards, + compute_params=compute_params, + ) + + +class LigerTiledSwiGLUMLP(nn.Module): + """ + Memory-efficient SwiGLU MLP using tiled computation. + + This module combines SwiGLU activation with tiled processing to handle + very long sequences efficiently. The forward pass is recomputed during + backward to save memory. + + Args: + config: Model configuration with hidden_size and intermediate_size attributes + num_shards: Number of shards to split the sequence. If None, automatically + calculated as ceil(seqlen / hidden_size) + """ + + def __init__(self, config, num_shards: Optional[int] = None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.num_shards = num_shards + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + # Validate activation function + if hasattr(config, "hidden_act") and config.hidden_act not in ["silu", "swish"]: + raise ValueError(f"LigerTiledSwiGLUMLP requires SiLU/Swish activation, got {config.hidden_act}") + + def _mlp_forward(self, module, x): + """Internal MLP forward function for tiled computation.""" + gate = module.gate_proj(x) + up = module.up_proj(x) + return module.down_proj(LigerSiLUMulFunction.apply(gate, up)) + + def forward(self, x): + """ + Forward pass with tiled computation. 
+ + Args: + x: Input tensor of shape [batch_size, seq_len, hidden_size] + or [seq_len, hidden_size] + + Returns: + Output tensor of the same shape as input + """ + compute_params = [p for p in self.parameters() if p.requires_grad] + + return apply_tiled_mlp( + fn=self._mlp_forward, + mlp_module=self, + x=x, + num_shards=self.num_shards, + compute_params=compute_params, + ) diff --git a/src/liger_kernel/transformers/trainer/__init__.py b/src/liger_kernel/transformers/trainer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..df5de2038ace73024ec5d42933990cd178aaeeff --- /dev/null +++ b/src/liger_kernel/transformers/trainer/__init__.py @@ -0,0 +1,4 @@ +try: + from liger_kernel.transformers.trainer.orpo_trainer import LigerORPOTrainer # noqa: F401 +except ImportError: + raise ImportError("Please `pip install trl` to use LigerORPOTrainer") diff --git a/src/liger_kernel/transformers/trainer/orpo_trainer.py b/src/liger_kernel/transformers/trainer/orpo_trainer.py new file mode 100755 index 0000000000000000000000000000000000000000..6ae10b35b99770e89f7fdd1481d210e55b1d71f3 --- /dev/null +++ b/src/liger_kernel/transformers/trainer/orpo_trainer.py @@ -0,0 +1,130 @@ +from typing import Dict +from typing import List +from typing import Literal +from typing import Tuple +from typing import Union + +import torch +import torch.nn as nn + +from torch.distributed.fsdp import FullyShardedDataParallel +from trl.trainer import ORPOTrainer + +from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss +from liger_kernel.transformers.fsdp import _FSDPForwardRedirection + + +class LigerORPOTrainer(ORPOTrainer): + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """ + Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + We do this to avoid doing two forward passes, because it's faster for FSDP. 
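+ Note: only the base model (model.model, or the FSDP-wrapped module's .model) is run here to obtain hidden + states; the lm_head projection is folded into LigerFusedLinearORPOLoss below, so full-vocabulary logits are + never materialized (see the fused loss call later in this method).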
+ """ + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + padding_value=self.padding_value, + device=self.accelerator.device, + ) + + model_kwargs = ( + { + "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]), + } + if self.is_encoder_decoder + else {} + ) + + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + if self.is_encoder_decoder: + labels = concatenated_batch["concatenated_labels"].clone() + else: + labels = concatenated_batch["concatenated_input_ids"].clone() + attention_mask = concatenated_batch["concatenated_attention_mask"] + labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) + + if isinstance(model, FullyShardedDataParallel): + outputs = _FSDPForwardRedirection()( + model, + model._fsdp_wrapped_module.model, + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + **model_kwargs, + ) + else: + if isinstance(model, torch.nn.DataParallel): + model = model.module + outputs = model.model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + **model_kwargs, + ) + + orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta) + + def orpo_partial(lm_head, last_hidden_state, concatenated_labels, nll_target): + return orpo_loss_fn( + lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias, nll_target=nll_target + ) + + orpo_loss, aux_outputs = _FSDPForwardRedirection()( + model, + orpo_partial, + model.lm_head, + outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state, + concatenated_batch["concatenated_labels"][:, 1:] + if not self.is_encoder_decoder + else concatenated_batch["concatenated_labels"], + labels[:, 1:] if not self.is_encoder_decoder else labels, + ) + # if aux_loss_enabled, add the aux_loss to the orpo_loss + if self.aux_loss_enabled: + orpo_loss += self.aux_loss_coef * outputs.aux_loss + + return orpo_loss, aux_outputs + + def get_batch_loss_metrics( + self, + model, + batch: Dict[str, Union[List, torch.LongTensor]], + train_eval: Literal["train", "eval"] = "train", + ): + """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + loss, aux_outputs = self.concatenated_forward(model, batch) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + ) = aux_outputs[:5] + + # return loss, metrics + chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[5:] + + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean() + metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean() + metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean() + metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean() + metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean() + metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean() + metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean() + metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean() + metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean() + 
metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio + metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen + for k, v in metrics.items(): + metrics[k] = v.item() + + return loss, metrics diff --git a/src/liger_kernel/transformers/trainer_integration.py b/src/liger_kernel/transformers/trainer_integration.py new file mode 100755 index 0000000000000000000000000000000000000000..623ceab543aaa4253f56217ac2a93f0597644b5d --- /dev/null +++ b/src/liger_kernel/transformers/trainer_integration.py @@ -0,0 +1,2 @@ +# To not break HF Trainer integration +from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401 diff --git a/src/liger_kernel/transformers/tvd.py b/src/liger_kernel/transformers/tvd.py new file mode 100755 index 0000000000000000000000000000000000000000..b57a4898ca2da8c809a50f059d56b3b9c9c9608c --- /dev/null +++ b/src/liger_kernel/transformers/tvd.py @@ -0,0 +1,13 @@ +import torch.nn as nn + +from liger_kernel.ops import LigerTVDLossFunction + + +class LigerTVDLoss(nn.Module): + def __init__(self, reduction="batchmean", ignore_index: int = -100): + super(LigerTVDLoss, self).__init__() + self.reduction = reduction + self.ignore_index = ignore_index + + def forward(self, p, q, shift_labels=None): + return LigerTVDLossFunction.apply(p, q, shift_labels, self.reduction, self.ignore_index) diff --git a/src/liger_kernel/triton/__init__.py b/src/liger_kernel/triton/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..d373966a9bb488ce448ec6989c919967382fe8c7 --- /dev/null +++ b/src/liger_kernel/triton/__init__.py @@ -0,0 +1 @@ +from liger_kernel.triton.monkey_patch import apply_liger_triton_cache_manager # noqa: F401 diff --git a/src/liger_kernel/triton/monkey_patch.py b/src/liger_kernel/triton/monkey_patch.py new file mode 100755 index 0000000000000000000000000000000000000000..bac4a6a0d6a8fc74b56562bdba3c659e175c39ca --- /dev/null +++ b/src/liger_kernel/triton/monkey_patch.py @@ -0,0 +1,40 @@ +import os +import random + +from triton.runtime.cache import FileCacheManager + + +class LigerTritonFileCacheManager(FileCacheManager): + def put(self, data, filename, binary=True) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + binary = isinstance(data, bytes) + if not binary: + data = str(data) + assert self.lock_path is not None + filepath = self._make_path(filename) + # Random ID to avoid any collisions + rnd_id = random.randint(0, 1000000) + # we use the PID incase a bunch of these around so we can see what PID made it + pid = os.getpid() + # use temp dir to be robust against program interruptions + temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, filename) + + mode = "wb" if binary else "w" + with open(temp_path, mode) as f: + f.write(data) + # Replace is guaranteed to be atomic on POSIX systems if it succeeds + # so filepath cannot see a partial write + os.replace(temp_path, filepath) + os.removedirs(temp_dir) + return filepath + + +def apply_liger_triton_cache_manager(): + """ + Experimental feature to get around transient FileNotFoundError in triton compilation. 
For more details please see https://github.com/triton-lang/triton/pull/4295 + """ + os.environ["TRITON_CACHE_MANAGER"] = "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager" diff --git a/src/liger_kernel/utils.py b/src/liger_kernel/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..a6bbc31b5e9b3085d3c73ded576d74653128cc02 --- /dev/null +++ b/src/liger_kernel/utils.py @@ -0,0 +1,125 @@ +try: + import peft # noqa: F401 + + PEFT_AVAILABLE = True +except ImportError: + PEFT_AVAILABLE = False + +import torch + + +def is_peft_available(): + return PEFT_AVAILABLE + + +def infer_comm_backend(): + """ + Get communication backend name based on the environment. + """ + if torch.distributed.is_nccl_available(): + # Works for Nvidia + # TODO: nccl may not work for AMD devices, which may require use of rccl. + return "nccl" + elif is_npu_available(): + # Use Ascend NPU if available (torch.npu) + # Ascend is not a standard torch backend and requires an extension. + # Assume that it is installed if NPUs are being used in + # a multi-device environment. + return "ascend" + # XPU (Intel) if available + elif torch.distributed.distributed_c10d.is_xccl_available(): + return "xccl" + elif torch.distributed.is_mpi_available(): + # CPU backend, first option + return "mpi" + elif torch.distributed.is_gloo_available(): + # CPU backend, backup option + return "gloo" + else: + raise RuntimeError("There is no distributed backend available.") + + +def infer_device(): + """ + Get current device name based on available devices + """ + if torch.cuda.is_available(): # Works for both Nvidia and AMD + return "cuda" + # Use Ascend NPU if available (torch.npu) + elif is_npu_available(): + return "npu" + # XPU (Intel) if available + elif torch.xpu.is_available(): + return "xpu" + else: + return "cpu" + + +def is_npu_available() -> bool: + """Detect Ascend NPU availability.""" + try: + from transformers.utils import is_torch_npu_available + + return is_torch_npu_available() + except Exception: + return False + + +def transformers_version_dispatch( + required_version: str, + before_fn, + after_fn, + before_args: tuple = (), + after_args: tuple = (), + before_kwargs: dict = None, + after_kwargs: dict = None, +): + """ + Dispatches to different functions based on package version comparison. + + Args: + required_version: Version to compare against (e.g. "4.48.0") + before_fn: Function to call if package_version < required_version + after_fn: Function to call if package_version >= required_version + before_args: Positional arguments for before_fn + after_args: Positional arguments for after_fn + before_kwargs: Keyword arguments for before_fn + after_kwargs: Keyword arguments for after_fn + + Returns: + Result from either before_fn or after_fn + + Example: + >>> rotary_emb = transformers_version_dispatch( + ... "4.48.0", + ... LlamaRotaryEmbedding, + ... LlamaRotaryEmbedding, + ... before_args=(head_dim,), + ... after_args=(LlamaConfig(head_dim=head_dim),), + ... before_kwargs={'device': device}, + ... after_kwargs={'device': device} + ... 
) + """ + from packaging import version + from transformers import __version__ as transformers_version + + before_kwargs = before_kwargs or {} + after_kwargs = after_kwargs or {} + + if version.parse(transformers_version) < version.parse(required_version): + return before_fn(*before_args, **before_kwargs) + else: + return after_fn(*after_args, **after_kwargs) + + +def get_total_gpu_memory() -> int: + """Returns total GPU memory in GiB.""" + device = infer_device() + if device == "cuda": + return torch.cuda.get_device_properties(0).total_memory // (1024**3) + elif device == "xpu": + return torch.xpu.get_device_properties(0).total_memory // (1024**3) + elif device == "npu": + return torch.npu.get_device_properties(0).total_memory // (1024**3) + else: + raise RuntimeError(f"Unsupported device: {device}") diff --git a/test/__init__.py b/test/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/chunked_loss/__init__.py b/test/chunked_loss/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/chunked_loss/test_cosine_loss.py b/test/chunked_loss/test_cosine_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..b2711c04d875666830bb9a9c3030d1829ecac253 --- /dev/null +++ b/test/chunked_loss/test_cosine_loss.py @@ -0,0 +1,320 @@ +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction +from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss +from liger_kernel.chunked_loss.functional import liger_fused_linear_cosine +from liger_kernel.utils import infer_device +from test.utils import HFDistillationLoss +from test.utils import assert_verbose_allclose +from test.utils import set_seed + +device = infer_device() +set_seed() + + +class HFCosineLoss(HFDistillationLoss): + """ + Implementation of a distillation loss using cosine similarity. + """ + + def __init__( + self, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + ): + super().__init__( + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + ignore_index=ignore_index, + temperature=temperature, + ) + + def distillation_loss(self, student_logits, teacher_logits, target=None, ignore_index=None, beta=1.0, **kwargs): + # Compute normalized logits + student_norm = F.normalize(student_logits, p=2, dim=-1) + teacher_norm = F.normalize(teacher_logits, p=2, dim=-1) + cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1) + + loss = beta * (1 - cosine_sim) + return loss.mean() + + +class TorchCosineLoss(torch.nn.Module): + """ + Reference implementation for Cosine Similarity Loss using standard torch operations. + Computes the loss as 1 - cosine_similarity averaged over all tokens. + """ + + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool, + device: torch.device, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + beta: float = 1.0, + ignore_index: int = -100, + temperature: float = 1.0, + ): + super().__init__() + # Note: student inputs are expected to have hidden size H//2 while teacher inputs have H.
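+ # (assumption) the halved student width emulates distilling from a wider teacher into a smaller student; + # both linear heads still project to the same vocabulary size V, so their logits remain comparable.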
+ self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype, device=device) + self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype, device=device) + self.beta = beta + self.cosine = HFCosineLoss( + ignore_index=ignore_index, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + temperature=temperature, + ).get_batch_loss_metrics + + def forward(self, student_input, teacher_input, target): + loss = self.cosine( + student_input, + self.student_lin.weight, + teacher_input, + self.teacher_lin.weight, + target, + self.student_lin.bias, + self.teacher_lin.bias, + beta=self.beta, + ) + return loss + + +class LigerCosineLoss(torch.nn.Module): + """ + Liger implementation that uses fused cosine similarity loss. + """ + + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool, + device: torch.device, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + compiled: bool = True, + chunk_size: int = 1024, + ): + super().__init__() + self.chunked_cosine = LigerFusedLinearCosineSimilarityLoss( + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + beta=beta, + ignore_index=ignore_index, + temperature=temperature, + compiled=compiled, + chunk_size=chunk_size, + ) + self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype, device=device) + self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype, device=device) + + def forward(self, student_input, teacher_input, target): + return self.chunked_cosine( + student_input, + self.student_lin.weight, + teacher_input, + self.teacher_lin.weight, + target, + self.student_lin.bias, + self.teacher_lin.bias, + ) + + +############################################################################### +# Test correctness of the module implementations +############################################################################### + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 32, 128), # H must be even + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize( + "temperature, weight_hard_loss, weight_soft_loss, beta", + [ + (1.0, 0.5, 0.5, 0.5), + (2.0, 0.0, 1.0, 0.8), + (0.5, 1.0, 0.0, 0.2), + ], +) +def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, temperature, weight_hard_loss, weight_soft_loss, beta +): + torch_cosine = TorchCosineLoss( + H=H, + V=V, + dtype=dtype, + bias=bias, + device=device, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + temperature=temperature, + beta=beta, + ) + liger_cosine = LigerCosineLoss( + H=H, + V=V, + dtype=dtype, + bias=bias, + device=device, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + temperature=temperature, + beta=beta, + ) + # Ensure both implementations start with the same weights and biases. 
torch_cosine.student_lin.weight.data = liger_cosine.student_lin.weight.data = torch.rand( + V, H // 2, device=device, dtype=dtype + ) + torch_cosine.teacher_lin.weight.data = liger_cosine.teacher_lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) + if bias: + torch_cosine.student_lin.bias.data = liger_cosine.student_lin.bias.data = torch.rand( + V, device=device, dtype=dtype + ) + torch_cosine.teacher_lin.bias.data = liger_cosine.teacher_lin.bias.data = torch.rand( + V, device=device, dtype=dtype + ) + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + student_input1 = _tensor.clone().detach().requires_grad_(True) + student_input2 = _tensor.clone().detach().requires_grad_(True) + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + # Dummy target (not used in cosine computation) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + loss1 = torch_cosine(student_input1, teacher_input, target) + loss2 = liger_cosine(student_input2, teacher_input, target) + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(student_input1.grad, student_input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_cosine.student_lin.weight.grad, liger_cosine.student_lin.weight.grad, atol=atol, rtol=rtol + ) + if bias: + assert_verbose_allclose( + torch_cosine.student_lin.bias.grad, liger_cosine.student_lin.bias.grad, atol=atol, rtol=rtol + ) + + +############################################################################### +# Test correctness of the functional interface +############################################################################### + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + (9, 7, 40, 40), # H must be even + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-2), + (1.0, torch.float32, 1e-4, 5e-3), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize( + "temperature, weight_hard_loss, weight_soft_loss, beta, ignore_index", + [ + (1.0, 0.5, 0.5, 0.5, -100), + (2.0, 0.1, 0.9, 0.5, 42), + ], +) +def test_correctness_functional( + B, T, H, V, scalar, dtype, bias, weight_hard_loss, weight_soft_loss, beta, ignore_index, temperature, atol, rtol +): + # Prepare weights and biases for functional testing.
+ student_weight1 = torch.rand(V, H // 2, device=device, dtype=dtype).detach().clone().requires_grad_(True) + student_weight2 = student_weight1.clone().detach().requires_grad_(True) + teacher_weight = torch.rand(V, H, device=device, dtype=dtype) + + if bias: + student_bias1 = torch.rand(V, device=device, dtype=dtype).detach().clone().requires_grad_(True) + student_bias2 = student_bias1.clone().detach().requires_grad_(True) + teacher_bias = torch.rand(V, device=device, dtype=dtype) + else: + student_bias1 = student_bias2 = teacher_bias = None + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + student_input1 = _tensor.clone().detach().requires_grad_(True) + student_input2 = _tensor.clone().detach().requires_grad_(True) + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + # Functional call using the fused cosine similarity function + output1 = liger_fused_linear_cosine( + student_input1, + student_weight1, + teacher_input, + teacher_weight, + target, + student_bias1, + teacher_bias, + weight_hard_loss, + weight_soft_loss, + beta, + ignore_index, + temperature, + True, + 1024, + ) + output2 = LigerFusedLinearCosineSimilarityFunction.apply( + student_input2, + student_weight2, + teacher_input, + teacher_weight, + target, + student_bias2, + teacher_bias, + weight_hard_loss, + weight_soft_loss, + beta, + ignore_index, + temperature, + True, + 1024, + ) + + assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) + output1.backward() + output2.backward() + + assert_verbose_allclose(student_input1.grad, student_input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose(student_weight1.grad, student_weight2.grad, atol=atol, rtol=rtol) + if bias: + assert_verbose_allclose(student_bias1.grad, student_bias2.grad, atol=atol, rtol=rtol) diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..c996e57f9de2628ee18d7541d3cdcaeed4a9c98f --- /dev/null +++ b/test/chunked_loss/test_cpo_loss.py @@ -0,0 +1,302 @@ +from typing import Tuple + +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss import LigerFusedLinearCPOLoss +from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction +from liger_kernel.chunked_loss.functional import liger_fused_linear_cpo +from liger_kernel.utils import infer_device +from test.utils import HFAlignmentLoss +from test.utils import assert_verbose_allclose +from test.utils import set_seed + +device = infer_device() + +# set random seed globally +set_seed() + + +class HFCPOLoss(HFAlignmentLoss): + """ + HF's implementation of CPO loss in TRL. https://github.com/huggingface/trl/blob/main/trl/trainer/cpo_trainer.py + """ + + def __init__( + self, + alpha: float = 1.0, + beta: float = 0.1, + ignore_index: int = -100, + label_smoothing: float = 0.0, + simpo_gamma: float = 0.5, + loss_type: str = "sigmoid", + ): + super().__init__(alpha=alpha, beta=beta, ignore_index=ignore_index) + # Sigmoid defaults to the CPO loss defined in the paper listed above. 
+ self.loss_type = loss_type + self.label_smoothing = label_smoothing + self.simpo_gamma = simpo_gamma + + def alignment_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the CPO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + The losses tensor contains the CPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + """ + logits = policy_chosen_logps - policy_rejected_logps + + # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5. + # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and + # calculates a conservative CPO loss. + if self.loss_type == "sigmoid": + # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0. + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + elif self.loss_type == "simpo": + logits = logits - (self.simpo_gamma / self.beta) + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + else: + raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid']") + + chosen_rewards = self.beta * policy_chosen_logps + rejected_rewards = self.beta * policy_rejected_logps + + return losses, chosen_rewards, rejected_rewards + + +class TorchLMHeadCPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + alpha: float = 1.0, + label_smoothing: float = 0.0, + loss_type: str = "sigmoid", + simpo_gamma: float = 0.5, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.cpo_loss = HFCPOLoss( + ignore_index=ignore_index, + beta=beta, + loss_type=loss_type, + label_smoothing=label_smoothing, + simpo_gamma=simpo_gamma, + ).get_batch_loss_metrics + self.average_log_prob = loss_type == "simpo" + + def forward(self, x, y): + return self.cpo_loss(self.lin.weight, x, y, self.lin.bias, average_log_prob=self.average_log_prob) + + +class LigerLMHeadCPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + alpha: float = 1.0, + label_smoothing: float = 0.0, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.cpo_loss = LigerFusedLinearCPOLoss( + ignore_index=ignore_index, + beta=beta, + alpha=alpha, + label_smoothing=label_smoothing, + ) + + def forward(self, x, y): + return self.cpo_loss(self.lin.weight, x, y, self.lin.bias) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-2), + (1.0, torch.float32, 1e-5, 5e-4), + 
], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)]) +@pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) +def test_correctness( + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + ignore_index, + beta, + alpha, + label_smoothing, +): + B = 2 * B # cpo loss requires B to be even + + torch_lm_head_cpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + label_smoothing=label_smoothing, + ) + liger_lm_head_cpo = LigerLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + label_smoothing=label_smoothing, + ) + + torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_cpo.lin.bias.data = liger_lm_head_cpo.lin.bias.data = torch.randn(V, device=device, dtype=dtype) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_cpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_cpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head_cpo.lin.weight.grad, + liger_lm_head_cpo.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + if bias: + assert_verbose_allclose( + torch_lm_head_cpo.lin.bias.grad, + liger_lm_head_cpo.lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): + B = 2 * B + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + + _weight = torch.randn(V, H, device=device, dtype=dtype) + weight1 = _weight.detach().clone().requires_grad_(True) + weight2 = _weight.detach().clone().requires_grad_(True) + + _bias = torch.randn(V, device=device, dtype=dtype) if bias else None + bias1 = _bias.detach().clone().requires_grad_(True) if bias else None + bias2 = _bias.detach().clone().requires_grad_(True) if bias else None + + loss1, aggregated_aux_outputs1 = LigerFusedLinearCPOFunction.apply(input1, weight1, target, bias1) + loss2, aggregated_aux_outputs2 = 
liger_fused_linear_cpo(input2, weight2, target, bias2) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol) + if bias: + assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol) diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..de5762f26682e0337ab1f1ec382695d0578ae3d1 --- /dev/null +++ b/test/chunked_loss/test_dpo_loss.py @@ -0,0 +1,938 @@ +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss +from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction +from liger_kernel.chunked_loss.functional import liger_fused_linear_dpo +from liger_kernel.utils import infer_device +from test.utils import HFAlignmentLoss +from test.utils import assert_verbose_allclose +from test.utils import set_seed + +device = infer_device() + +# set random seed globally +set_seed() + + +class HFDPOLoss(HFAlignmentLoss): + """ + Implementation of the Direct Preference Optimization (DPO) loss, + adapted from Hugging Face's implementation. + Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py + """ + + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + use_ref_model: bool = True, + compute_nll_loss: bool = False, + ): + super().__init__( + beta=beta, + ignore_index=ignore_index, + use_ref_model=use_ref_model, + compute_nll_loss=compute_nll_loss, + ) + + def alignment_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ref_chosen_logps: torch.FloatTensor, + ref_rejected_logps: torch.FloatTensor, + ): + """Compute DPO loss for a batch of policy log probabilities. + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + The losses tensor contains the DPO loss for each example in the batch. + """ + # Derived from https://huggingface.co/papers/2305.18290 + chosen_logratios = policy_chosen_logps - ref_chosen_logps + rejected_logratios = policy_rejected_logps - ref_rejected_logps + + chosen_rewards = self.beta * chosen_logratios + rejected_rewards = self.beta * rejected_logratios + + logits_diff = self.beta * (chosen_logratios - rejected_logratios) + losses = -F.logsigmoid(logits_diff) + return losses, chosen_rewards, rejected_rewards + + +class HFAPOZeroLoss(HFAlignmentLoss): + """ + Implementation of the APO-zero loss. + Reference: https://huggingface.co/papers/2408.06266 + """ + + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + use_ref_model: bool = True, + compute_nll_loss: bool = False, + ): + super().__init__( + beta=beta, + ignore_index=ignore_index, + use_ref_model=use_ref_model, + compute_nll_loss=compute_nll_loss, + ) + + def alignment_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ref_chosen_logps: torch.FloatTensor, + ref_rejected_logps: torch.FloatTensor, + ): + """Compute APO-zero loss for a batch of policy log probabilities. + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. 
Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + The losses tensor contains the APO-zero loss for each example in the batch. + """ + # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) + chosen_logratios = policy_chosen_logps - ref_chosen_logps + rejected_logratios = policy_rejected_logps - ref_rejected_logps + + chosen_rewards = self.beta * chosen_logratios + rejected_rewards = self.beta * rejected_logratios + + # Use this loss when you believe the chosen outputs are better than your model's default output + losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood + losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood + losses = losses_chosen + losses_rejected + + return losses, chosen_rewards, rejected_rewards + + +class HFAPODownLoss(HFAlignmentLoss): + """ + Implementation of the APO-down loss. + Reference: https://huggingface.co/papers/2408.06266 + """ + + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + use_ref_model: bool = True, + compute_nll_loss: bool = False, + ): + super().__init__( + beta=beta, + ignore_index=ignore_index, + use_ref_model=use_ref_model, + compute_nll_loss=compute_nll_loss, + ) + + def alignment_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ref_chosen_logps: torch.FloatTensor, + ref_rejected_logps: torch.FloatTensor, + ): + """Compute APO-down loss for a batch of policy log probabilities. + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + The losses tensor contains the APO-down loss for each example in the batch. + """ + # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266) + chosen_logratios = policy_chosen_logps - ref_chosen_logps + rejected_logratios = policy_rejected_logps - ref_rejected_logps + + chosen_rewards = self.beta * chosen_logratios + rejected_rewards = self.beta * rejected_logratios + + # Use this loss when you believe the chosen outputs are worse than your model's default output. 
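+        # In closed form (matching the code below):
+        #   losses = sigmoid(beta * chosen_logratios) + 1 - sigmoid(beta * (chosen_logratios - rejected_logratios))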
+ # Decrease chosen likelihood and decrease rejected likelihood more + losses_chosen = F.sigmoid(self.beta * chosen_logratios) + losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios)) + losses = losses_chosen + losses_rejected + + return losses, chosen_rewards, rejected_rewards + + +class HFSPPPOHARDLoss(HFAlignmentLoss): + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + use_ref_model: bool = True, + compute_nll_loss: bool = False, + ): + super().__init__( + beta=beta, + ignore_index=ignore_index, + use_ref_model=use_ref_model, + compute_nll_loss=compute_nll_loss, + ) + + def alignment_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ref_chosen_logps: torch.FloatTensor, + ref_rejected_logps: torch.FloatTensor, + ): + chosen_logratios = policy_chosen_logps - ref_chosen_logps + rejected_logratios = policy_rejected_logps - ref_rejected_logps + + chosen_rewards = self.beta * chosen_logratios + rejected_rewards = self.beta * rejected_logratios + + a = policy_chosen_logps - ref_chosen_logps + b = policy_rejected_logps - ref_rejected_logps + losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2 + + return losses, chosen_rewards, rejected_rewards + + +class HFNCAPAIRLoss(HFAlignmentLoss): + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + use_ref_model: bool = True, + compute_nll_loss: bool = False, + ): + super().__init__( + beta=beta, + ignore_index=ignore_index, + use_ref_model=use_ref_model, + compute_nll_loss=compute_nll_loss, + ) + + def alignment_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ref_chosen_logps: torch.FloatTensor, + ref_rejected_logps: torch.FloatTensor, + ): + chosen_logratios = policy_chosen_logps - ref_chosen_logps + rejected_logratios = policy_rejected_logps - ref_rejected_logps + + chosen_rewards = self.beta * chosen_logratios + rejected_rewards = self.beta * rejected_logratios + + losses = ( + -F.logsigmoid(chosen_rewards) - 0.5 * F.logsigmoid(-chosen_rewards) - 0.5 * F.logsigmoid(-rejected_rewards) + ) + + return losses, chosen_rewards, rejected_rewards + + +class TorchLMHeadDPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ref_bias: bool = False, + compute_nll_loss: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype) + self.dpo_loss = HFDPOLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + compute_nll_loss=compute_nll_loss, + ).get_batch_loss_metrics + + def forward(self, x, ref_x, y): + return self.dpo_loss( + self.lin.weight, + x, + y, + self.lin.bias, + ref_x, + self.ref_lin.weight, + self.ref_lin.bias, + average_log_prob=True, + ) + + +class TorchLMHeadAPOZero(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ref_bias: bool = False, + compute_nll_loss: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype) + self.apo_loss = HFAPOZeroLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + 
compute_nll_loss=compute_nll_loss, + ).get_batch_loss_metrics + + def forward(self, x, ref_x, y): + return self.apo_loss( + self.lin.weight, + x, + y, + self.lin.bias, + ref_x, + self.ref_lin.weight, + self.ref_lin.bias, + average_log_prob=True, + ) + + +class TorchLMHeadAPODown(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ref_bias: bool = False, + compute_nll_loss: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype) + self.apo_loss = HFAPODownLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + compute_nll_loss=compute_nll_loss, + ).get_batch_loss_metrics + + def forward(self, x, ref_x, y): + return self.apo_loss( + self.lin.weight, + x, + y, + self.lin.bias, + ref_x, + self.ref_lin.weight, + self.ref_lin.bias, + average_log_prob=True, + ) + + +class TorchLMHeadSPPOHARD(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ref_bias: bool = False, + compute_nll_loss: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype) + self.sppo_hard = HFSPPPOHARDLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + compute_nll_loss=compute_nll_loss, + ).get_batch_loss_metrics + + def forward(self, x, ref_x, y): + return self.sppo_hard( + self.lin.weight, + x, + y, + self.lin.bias, + ref_x, + self.ref_lin.weight, + self.ref_lin.bias, + average_log_prob=True, + ) + + +class TorchLMHeadNCAPAIR(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ref_bias: bool = False, + compute_nll_loss: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype) + self.nca_pair = HFNCAPAIRLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + compute_nll_loss=compute_nll_loss, + ).get_batch_loss_metrics + + def forward(self, x, ref_x, y): + return self.nca_pair( + self.lin.weight, + x, + y, + self.lin.bias, + ref_x, + self.ref_lin.weight, + self.ref_lin.bias, + average_log_prob=True, + ) + + +class LigerLMHeadDPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ref_bias: bool = False, + compute_nll_loss: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + loss_type: str = "sigmoid", + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype) + self.dpo_loss = LigerFusedLinearDPOLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + compute_nll_loss=compute_nll_loss, + average_log_prob=True, + loss_type=loss_type, + ) + + def forward(self, x, ref_x, y): + return self.dpo_loss( + self.lin.weight, + x, + y, + self.lin.bias, + ref_x, + self.ref_lin.weight, + self.ref_lin.bias, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + 
(3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("ref_bias", [True, False]) +@pytest.mark.parametrize("compute_nll_loss", [True, False]) +@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)]) +def test_correctness( + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + ref_bias, + compute_nll_loss, + ignore_index, + beta, +): + B = 2 * B # dpo loss requires B to be even + + torch_lm_head_dpo = TorchLMHeadDPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ref_bias=ref_bias, + compute_nll_loss=compute_nll_loss, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_dpo = LigerLMHeadDPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ref_bias=ref_bias, + compute_nll_loss=compute_nll_loss, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_dpo.lin.weight.data = liger_lm_head_dpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + torch_lm_head_dpo.ref_lin.weight.data = liger_lm_head_dpo.ref_lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_dpo.lin.bias.data = liger_lm_head_dpo.lin.bias.data = torch.randn(V, device=device, dtype=dtype) + if ref_bias: + torch_lm_head_dpo.ref_lin.bias.data = liger_lm_head_dpo.ref_lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_dpo(input1, ref_input, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_dpo(input2, ref_input, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + if i > 4 and dtype == torch.bfloat16: + # numerical instability in bf16 for chosen_rewards and rejected_rewards + # temporary fix. 
TODO: investigate how to reduce numerical instability issue
+            assert_verbose_allclose(
+                aggregated_aux_outputs1[i],
+                aggregated_aux_outputs2[i],
+                atol=5e-1,
+                rtol=rtol,
+            )
+            continue
+        assert_verbose_allclose(
+            aggregated_aux_outputs1[i],
+            aggregated_aux_outputs2[i],
+            atol=atol,
+            rtol=rtol,
+        )
+
+    loss1.backward()
+    loss2.backward()
+
+    assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
+    assert_verbose_allclose(
+        torch_lm_head_dpo.lin.weight.grad,
+        liger_lm_head_dpo.lin.weight.grad,
+        atol=atol,
+        rtol=rtol,
+    )
+    if bias:
+        assert_verbose_allclose(
+            torch_lm_head_dpo.lin.bias.grad,
+            liger_lm_head_dpo.lin.bias.grad,
+            atol=atol,
+            rtol=rtol,
+        )
+
+
+@pytest.mark.parametrize(
+    "B, T, H, V",
+    [
+        (2, 2, 8, 8),
+        (3, 47, 31, 123),  # random shape
+    ],
+)
+@pytest.mark.parametrize(
+    "scalar, dtype, atol, rtol",
+    [
+        (1.0, torch.bfloat16, 5e-2, 5e-1),
+        (1.0, torch.float32, 1e-5, 5e-4),
+    ],
+)
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("ref_bias", [True, False])
+@pytest.mark.parametrize("compute_nll_loss", [True, False])
+def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, compute_nll_loss):
+    B = 2 * B
+
+    _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
+    input1 = _input.detach().clone().requires_grad_(True)
+    input2 = _input.detach().clone().requires_grad_(True)
+
+    ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar
+
+    target = torch.randint(
+        0,
+        V,
+        (
+            B,
+            T,
+        ),
+        device=device,
+        dtype=torch.long,
+    )
+
+    _weight = torch.randn(V, H, device=device, dtype=dtype)
+    weight1 = _weight.detach().clone().requires_grad_(True)
+    weight2 = _weight.detach().clone().requires_grad_(True)
+
+    _ref_weight = torch.randn(V, H, device=device, dtype=dtype)
+    ref_weight1 = _ref_weight.detach().clone().requires_grad_(True)
+    ref_weight2 = _ref_weight.detach().clone().requires_grad_(True)
+
+    _bias = torch.randn(V, device=device, dtype=dtype) if bias else None
+    bias1 = _bias.detach().clone().requires_grad_(True) if bias else None
+    bias2 = _bias.detach().clone().requires_grad_(True) if bias else None
+
+    _ref_bias = torch.randn(V, device=device, dtype=dtype) if ref_bias else None
+    ref_bias1 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None
+    ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None
+
+    loss1, aggregated_aux_outputs1 = LigerFusedLinearDPOFunction.apply(
+        input1,
+        weight1,
+        target,
+        bias1,
+        ref_input,
+        ref_weight1,
+        ref_bias1,
+        -100,
+        0.1,
+        compute_nll_loss,
+    )
+    loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo(
+        input2,
+        weight2,
+        target,
+        bias2,
+        ref_input,
+        ref_weight2,
+        ref_bias2,
+        -100,
+        0.1,
+        compute_nll_loss,
+    )
+
+    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
+
+    loss1.backward()
+    loss2.backward()
+
+    assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
+    assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol)
+    if bias:
+        assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol)
+
+
+@pytest.mark.parametrize(
+    "B, T, H, V",
+    [
+        (8, 128, 1024, 4096),
+        (3, 47, 31, 123),  # random shape
+    ],
+)
+@pytest.mark.parametrize(
+    "scalar, dtype, atol, rtol",
+    [
+        (1.0, torch.bfloat16, 5e-2, 5e-1),
+        (1.0, torch.float32, 1e-5, 5e-4),
+    ],
+)
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("ref_bias", [True, False])
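+# Note: the loss_type strings parametrized below map onto the HF reference
+# implementations defined above: "apo_zero" -> HFAPOZeroLoss (eqn 7),
+# "apo_down" -> HFAPODownLoss (eqn 8), "sppo_hard" -> HFSPPPOHARDLoss,
+# "nca_pair" -> HFNCAPAIRLoss.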
+@pytest.mark.parametrize("compute_nll_loss", [True, False]) +@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)]) +@pytest.mark.parametrize("loss_type", ["apo_zero", "apo_down", "sppo_hard", "nca_pair"]) +def test_correctness_apo_loss_types( + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + ref_bias, + compute_nll_loss, + ignore_index, + beta, + loss_type, +): + B = 2 * B # dpo loss requires B to be even + + # Select the appropriate HF reference implementation + if loss_type == "apo_zero": + torch_lm_head = TorchLMHeadAPOZero + elif loss_type == "apo_down": + torch_lm_head = TorchLMHeadAPODown + elif loss_type == "sppo_hard": + torch_lm_head = TorchLMHeadSPPOHARD + elif loss_type == "nca_pair": + torch_lm_head = TorchLMHeadNCAPAIR + else: + raise ValueError(f"Unsupported loss_type: {loss_type}") + + torch_lm_head_apo = torch_lm_head( + H=H, + V=V, + dtype=dtype, + bias=bias, + ref_bias=ref_bias, + compute_nll_loss=compute_nll_loss, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_apo = LigerLMHeadDPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ref_bias=ref_bias, + compute_nll_loss=compute_nll_loss, + ignore_index=ignore_index, + beta=beta, + loss_type=loss_type, + ) + + torch_lm_head_apo.lin.weight.data = liger_lm_head_apo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + torch_lm_head_apo.ref_lin.weight.data = liger_lm_head_apo.ref_lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_apo.lin.bias.data = liger_lm_head_apo.lin.bias.data = torch.randn(V, device=device, dtype=dtype) + if ref_bias: + torch_lm_head_apo.ref_lin.bias.data = liger_lm_head_apo.ref_lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_apo(input1, ref_input, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_apo(input2, ref_input, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + if i > 4 and dtype == torch.bfloat16: + # numerical instability in bf16 for chosen_rewards and rejected_rewards + # temporary fix. 
TODO: investigate how to reduce numerical instability issue + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=5e-1, + rtol=rtol, + ) + continue + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head_apo.lin.weight.grad, + liger_lm_head_apo.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + if bias: + assert_verbose_allclose( + torch_lm_head_apo.lin.bias.grad, + liger_lm_head_apo.lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("ref_bias", [True, False]) +@pytest.mark.parametrize("compute_nll_loss", [True, False]) +@pytest.mark.parametrize("loss_type", ["apo_zero", "apo_down", "sppo_hard", "nca_pair"]) +def test_correctness_functional_apo_loss_types( + B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, compute_nll_loss, loss_type +): + B = 2 * B + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + + _weight = torch.randn(V, H, device=device, dtype=dtype) + weight1 = _weight.detach().clone().requires_grad_(True) + weight2 = _weight.detach().clone().requires_grad_(True) + + _ref_weight = torch.randn(V, H, device=device, dtype=dtype) + ref_weight1 = _ref_weight.detach().clone().requires_grad_(True) + ref_weight2 = _ref_weight.detach().clone().requires_grad_(True) + + _bias = torch.randn(V, device=device, dtype=dtype) if bias else None + bias1 = _bias.detach().clone().requires_grad_(True) if bias else None + bias2 = _bias.detach().clone().requires_grad_(True) if bias else None + + _ref_bias = torch.randn(V, device=device, dtype=dtype) if ref_bias else None + ref_bias1 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None + ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None + + # Call with loss_type parameter for LigerFusedLinearDPOFunction + loss1, aggregated_aux_outputs1 = LigerFusedLinearDPOFunction.apply( + input1, + weight1, + target, + bias1, + ref_input, + ref_weight1, + ref_bias1, + -100, + 0.1, + compute_nll_loss, + True, # compiled + True, # use_ref_model + False, # average_log_prob + 1, # chunk_size + loss_type, # loss_type + ) + + # For comparison, create a LigerFusedLinearDPOLoss with the loss_type + dpo_loss_fn = LigerFusedLinearDPOLoss( + ignore_index=-100, + beta=0.1, + compute_nll_loss=compute_nll_loss, + loss_type=loss_type, + ) + + loss2, aggregated_aux_outputs2 = dpo_loss_fn( + weight2, + input2, + target, + bias2, + ref_input, + ref_weight2, + ref_bias2, + ) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol) + if bias: + 
assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol) + + +def test_invalid_loss_type(): + """Test that invalid loss types raise ValueError""" + with pytest.raises(ValueError, match="Unsupported loss_type"): + LigerFusedLinearDPOLoss(loss_type="invalid_loss_type") + + # Test that valid loss types don't raise errors + valid_loss_types = ["sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"] + for loss_type in valid_loss_types: + # Should not raise an exception + loss_fn = LigerFusedLinearDPOLoss(loss_type=loss_type) + assert loss_fn.loss_type == loss_type diff --git a/test/chunked_loss/test_grpo_loss.py b/test/chunked_loss/test_grpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..59221a666a433741cdfddafacea43ab083a664dd --- /dev/null +++ b/test/chunked_loss/test_grpo_loss.py @@ -0,0 +1,993 @@ +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss +from liger_kernel.chunked_loss.functional import liger_fused_linear_grpo +from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase +from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction +from liger_kernel.transformers.grpo_loss import _reduce_grpo_loss +from liger_kernel.transformers.grpo_loss import triton_grpo_loss +from liger_kernel.utils import infer_device +from test.utils import assert_verbose_allclose +from test.utils import set_seed + +device = infer_device() + +# set random seed globally +set_seed() + + +def sapo_loss_fn(importance_ratio: torch.Tensor, temperature: float) -> torch.Tensor: + """SAPO (Soft Adaptive Policy Optimization) loss function for torch reference. + + Reference: https://huggingface.co/papers/2511.20347 + TRL implementation: https://github.com/huggingface/trl/blob/1bd2a52ec2d8344050af736d60cdc735181ae4b8/trl/trainer/grpo_trainer.py#L1913 + """ + if temperature <= 0: + raise ValueError("sapo_temperature must be > 0.") + sigmoid_input = temperature * (importance_ratio - 1) + sigmoid_smoothed_loss = torch.sigmoid(sigmoid_input) + return sigmoid_smoothed_loss * 4 / temperature + + +class TorchLMHeadGRPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + beta: float = 0.1, + epsilon_low: float = 0.2, + epsilon_high: float = 0.2, + temperature: float = 1.0, + use_ref_model: bool = True, + loss_type: str = "bnpo", + max_completion_length: int | None = None, + importance_sampling_level: str = "token", + sapo_temperature_pos: float = 1.0, + sapo_temperature_neg: float = 1.05, + delta: float | None = None, + use_bias_correction_kl: bool = False, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.beta = beta + self.epsilon_low = epsilon_low + self.epsilon_high = epsilon_high + self.temperature = temperature + self.use_ref_model = use_ref_model + self.loss_type = loss_type + self.max_completion_length = max_completion_length + self.importance_sampling_level = importance_sampling_level + self.sapo_temperature_pos = sapo_temperature_pos + self.sapo_temperature_neg = sapo_temperature_neg + self.delta = delta + self.use_bias_correction_kl = use_bias_correction_kl + if self.loss_type == "dr_grpo": + assert self.max_completion_length is not None, "max_completion_length must be provided for dr_grpo" + + @staticmethod + def compute_per_token_components( + 
per_token_logps, + attention_mask, + advantages, + old_per_token_logps, + ref_per_token_logps, + epsilon_low, + epsilon_high, + beta, + importance_sampling_level, + loss_type: str = "grpo", + sapo_temperature_pos: float = 1.0, + sapo_temperature_neg: float = 1.05, + vllm_is_ratio=None, + delta=None, + use_bias_correction_kl=False, + ): + attention_mask = attention_mask.to(per_token_logps.dtype) + old_per_token_logps = ( + old_per_token_logps.float() if old_per_token_logps is not None else per_token_logps.detach() + ) + log_ratio = per_token_logps - old_per_token_logps + + if importance_sampling_level == "token": + log_importance_weights = log_ratio + elif importance_sampling_level == "sequence": + log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0) + log_importance_weights = log_importance_weights.unsqueeze(-1) + else: + raise ValueError( + f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' " + "and 'sequence'." + ) + + coef_1 = torch.exp(log_importance_weights) + expanded_advantages = advantages.unsqueeze(1) + + if loss_type == "sapo": + # SAPO: Soft Adaptive Policy Optimization + # Uses sigmoid-based soft gating instead of hard clipping + # Reference: https://github.com/huggingface/trl/blob/1bd2a52ec2d8344050af736d60cdc735181ae4b8/trl/trainer/grpo_trainer.py#L2037-L2046 + per_token_loss = torch.empty_like(coef_1) + advantages_expanded = expanded_advantages.expand_as(coef_1) + positive_advantages_mask = advantages_expanded > 0 + + per_token_loss[positive_advantages_mask] = sapo_loss_fn( + coef_1[positive_advantages_mask], sapo_temperature_pos + ) + per_token_loss[~positive_advantages_mask] = sapo_loss_fn( + coef_1[~positive_advantages_mask], sapo_temperature_neg + ) + per_token_loss = -per_token_loss * advantages_expanded + # SAPO doesn't use clipping metrics + is_lower_clipped = torch.zeros_like(coef_1, dtype=torch.bool) + is_upper_clipped = torch.zeros_like(coef_1, dtype=torch.bool) + elif loss_type == "cispo": + # CISPO: clip and detach the importance weights + upper_bound = epsilon_high + lower_bound = None + coef_2 = torch.clamp(coef_1, lower_bound, upper_bound).detach() + is_lower_clipped = torch.zeros_like(coef_1, dtype=torch.bool) + is_upper_clipped = coef_1 > upper_bound + # CISPO: clip and detach the importance weights, multiply by log probs + # Reference: https://github.com/huggingface/trl/blob/035c3ff151b953ca72cdfe0ee966bc1469a26fde/trl/trainer/grpo_trainer.py#L2030 + per_token_loss = -coef_2 * expanded_advantages * per_token_logps + else: + upper_bound = 1 + epsilon_high + lower_bound = 1 - epsilon_low + coef_2 = torch.clamp(coef_1, lower_bound, upper_bound) + is_lower_clipped = coef_1 < lower_bound + is_upper_clipped = coef_1 > upper_bound + if delta is not None: + coef_1 = torch.clamp(coef_1, max=delta) + per_token_loss1 = coef_1 * expanded_advantages + per_token_loss2 = coef_2 * expanded_advantages + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + + # Apply vLLM importance sampling correction BEFORE KL penalty + if vllm_is_ratio is not None: + per_token_loss = per_token_loss * vllm_is_ratio + + kl_div = None + if beta != 0.0: + ref_per_token_logps = ref_per_token_logps.float() + kl_div = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0 + if use_bias_correction_kl: + token_coef_1 = torch.exp(per_token_logps - old_per_token_logps) + kl_div = kl_div * token_coef_1 + per_token_loss = per_token_loss + beta * kl_div + + # 
Adjust clipping metric calculation based on importance sampling level + if importance_sampling_level == "token": + is_clipped = (is_lower_clipped & (expanded_advantages < 0)) | (is_upper_clipped & (expanded_advantages > 0)) + else: # sequence level + # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,) + is_clipped = (is_lower_clipped & (expanded_advantages < 0)) | (is_upper_clipped & (expanded_advantages > 0)) + is_clipped = is_clipped.expand_as(attention_mask) + return per_token_loss, kl_div, is_clipped + + def forward( + self, + x, # Shape: [batch_size, seq_len, hidden_size] + selected_token_ids, # Shape: [batch_size, seq_len] + attention_mask, # Shape: [batch_size, seq_len] + advantages, # Shape: [batch_size,] + ref_per_token_logps=None, # Shape: [batch_size, seq_len] + old_per_token_logps=None, + ref_input=None, # Shape: [batch_size, seq_len, hidden_size] + vllm_is_ratio=None, # Shape: [batch_size, seq_len] or None + ): + logits = x @ self.lin.weight.t() + if self.lin.bias is not None: + logits = logits + self.lin.bias + if self.temperature != 1.0: + logits = logits / self.temperature + # Get log probabilities + log_probs = F.log_softmax(logits.float(), dim=-1) + + # Get chosen token probabilities + per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1) + + # Get reference model probabilities, + if ref_per_token_logps is None: + if self.use_ref_model: + with torch.no_grad(): + ref_logits = ref_input @ self.ref_lin.weight.t() + if self.ref_lin.bias is not None: + ref_logits = ref_logits + self.ref_lin.bias.float() + if self.temperature != 1.0: + ref_logits = ref_logits / self.temperature + ref_log_probs = F.log_softmax(ref_logits.float(), dim=-1) + ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze( + -1 + ) + else: + ref_per_token_logps = per_token_logps.detach() + + per_token_loss, kl_div, is_clipped = self.compute_per_token_components( + per_token_logps, + attention_mask, + advantages, + old_per_token_logps, + ref_per_token_logps, + self.epsilon_low, + self.epsilon_high, + self.beta, + self.importance_sampling_level, + self.loss_type, + self.sapo_temperature_pos, + self.sapo_temperature_neg, + vllm_is_ratio=vllm_is_ratio, + delta=self.delta, + use_bias_correction_kl=self.use_bias_correction_kl, + ) + + # Apply masking and calculate loss based on loss_type + if self.loss_type == "grpo" or self.loss_type == "sapo": + # SAPO uses same normalization as GRPO (per-sequence) + loss = ((per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)).mean() + elif self.loss_type == "bnpo": + loss = (per_token_loss * attention_mask).sum() / torch.clamp(attention_mask.sum(), min=1.0) + elif self.loss_type == "dr_grpo": + loss = (per_token_loss * attention_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + elif self.loss_type == "dapo": + normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(attention_mask) + loss = (per_token_loss * attention_mask).sum() / normalizer + elif self.loss_type == "cispo": + normalizer = attention_mask.sum().clamp(min=1.0) + loss = (per_token_loss * attention_mask).sum() / normalizer + elif self.loss_type == "luspo": + loss = (per_token_loss * attention_mask.sum(-1, keepdim=True)).mean() + else: + raise ValueError(f"Unknown loss type: {self.loss_type}") + + # Compute metrics + metrics = [] + if self.beta != 0.0: + metrics.append(((kl_div * attention_mask).sum() / torch.clamp(attention_mask.sum(), min=1.0))) + 
metrics.append((is_clipped.float() * attention_mask).sum() / torch.clamp(attention_mask.sum(), min=1.0)) + return loss, metrics + + +class LigerLMHeadGRPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + beta: float = 0.1, + epsilon_low: float = 0.2, + epsilon_high: float = 0.2, + temperature: float = 1.0, + use_ref_model: bool = True, + loss_type: str = "bnpo", + max_completion_length: int | None = None, + importance_sampling_level: str = "token", + sapo_temperature_pos: float = 1.0, + sapo_temperature_neg: float = 1.05, + delta: float | None = None, + use_bias_correction_kl: bool = False, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.grpo_loss = LigerFusedLinearGRPOLoss( + beta=beta, + epsilon_low=epsilon_low, + epsilon_high=epsilon_high, + temperature=temperature, + use_ref_model=use_ref_model, + compiled=True, + loss_type=loss_type, + max_completion_length=max_completion_length, + importance_sampling_level=importance_sampling_level, + sapo_temperature_pos=sapo_temperature_pos, + sapo_temperature_neg=sapo_temperature_neg, + delta=delta, + use_bias_correction_kl=use_bias_correction_kl, + ) + + def forward( + self, + x, + selected_token_ids, + attention_mask, + advantages, + ref_per_token_logps=None, + old_per_token_logps=None, + ref_input=None, + vllm_is_ratio=None, + ): + return self.grpo_loss( + x, # _input + self.lin.weight, # weight + selected_token_ids, # selected_token_ids + attention_mask, # attention_mask + advantages, # advantages + self.lin.bias, # bias + ref_per_token_logps, # ref_per_token_logps + old_per_token_logps, # old_per_token_logps + ref_input, # ref_input + self.ref_lin.weight, # ref_weight + self.ref_lin.bias, # ref_bias + vllm_is_ratio=vllm_is_ratio, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize( + "beta, epsilon_low, epsilon_high, temperature", + [ + # Standard settings + (0.1, 0.2, 0.2, 1.0), + (0.0, 0.1, 0.1, 2.0), + ], +) +@pytest.mark.parametrize( + "use_ref_model, use_ref_per_token_logps, old_per_token_logps", + [ + (True, True, True), + (True, False, False), + (False, False, True), + ], +) +@pytest.mark.parametrize("loss_type", ["bnpo", "grpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo"]) +@pytest.mark.parametrize("importance_sampling_level", ["token", "sequence"]) +@pytest.mark.parametrize("delta", [None, 2.0]) +def test_correctness( + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + beta, + epsilon_low, + epsilon_high, + temperature, + use_ref_per_token_logps, + use_ref_model, + old_per_token_logps, + loss_type, + importance_sampling_level, + delta, +): + if importance_sampling_level == "sequence" and loss_type in ("cispo", "sapo"): + pytest.skip(f"Sequence-level importance sampling is not supported for loss_type='{loss_type}'") + if delta is not None and loss_type in ("cispo", "sapo"): + pytest.skip(f"delta is not supported for loss_type='{loss_type}'") + + # LUSPO's formula multiplies per_token_loss by seq_lens, amplifying torch.compile + # numerical differences by O(T). Relax tolerances to account for this amplification. 
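+    # (See TorchLMHeadGRPO.forward above: for "luspo", each token's loss is multiplied by
+    # the sequence's unmasked length before the mean, which is where the O(T) factor enters.)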
+ if loss_type == "luspo": + if dtype == torch.bfloat16: + atol = max(atol, 1.0) + rtol = max(rtol, 5.0) + else: + atol = max(atol, 1e-4) + rtol = max(rtol, 5e-3) + + # Reset torch compiler cache for each parameter of the test case + torch.compiler.reset() + max_completion_length = T if loss_type == "dr_grpo" else None + + torch_lm_head_grpo = TorchLMHeadGRPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + beta=beta, + epsilon_low=epsilon_low, + epsilon_high=epsilon_high, + temperature=temperature, + use_ref_model=use_ref_model, + loss_type=loss_type, + max_completion_length=max_completion_length, + importance_sampling_level=importance_sampling_level, + delta=delta, + ) + liger_lm_head_grpo = LigerLMHeadGRPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + beta=beta, + epsilon_low=epsilon_low, + epsilon_high=epsilon_high, + temperature=temperature, + use_ref_model=use_ref_model, + loss_type=loss_type, + max_completion_length=max_completion_length, + importance_sampling_level=importance_sampling_level, + delta=delta, + ) + + # Initialize weights + torch_lm_head_grpo.lin.weight.data = liger_lm_head_grpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + if bias: + torch_lm_head_grpo.lin.bias.data = liger_lm_head_grpo.lin.bias.data = torch.randn(V, device=device, dtype=dtype) + + # set ref weights to be close to the original weights + torch_lm_head_grpo.ref_lin.weight.data = liger_lm_head_grpo.ref_lin.weight.data = ( + torch_lm_head_grpo.lin.weight.data + torch.randn(V, H, device=device, dtype=dtype) * 0.01 + ) + if bias: + torch_lm_head_grpo.ref_lin.bias.data = liger_lm_head_grpo.ref_lin.bias.data = ( + torch_lm_head_grpo.lin.bias.data + torch.randn(V, device=device, dtype=dtype) * 0.01 + ) + + # Create inputs with shape [B, T, H] + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + # Create selected token ids with shape [B, T] + selected_token_ids = torch.randint(0, V, (B, T), device=device) + + # Compute per-token logps + with torch.no_grad(): + logits = _input @ torch_lm_head_grpo.lin.weight.t() + if torch_lm_head_grpo.lin.bias is not None: + logits = logits + torch_lm_head_grpo.lin.bias + logits = logits / temperature + logps = F.log_softmax(logits, dim=-1) + per_token_logps = logps.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1) + + # Create attention mask with random padding [B, T] + attention_mask = torch.ones(B, T, device=device) + num_elements_to_mask = torch.randint(1, B * T // 2, (1,)).item() + mask_indices = torch.randperm(B * T)[:num_elements_to_mask] + attention_mask.view(-1)[mask_indices] = 0 + + # Create advantages with shape [B] and ensure mixed signs for SAPO + advantages = torch.randn(B, device=device, dtype=dtype) + advantages[0] = -advantages[0].abs() + if B > 1: + advantages[1] = advantages[1].abs() + + ref_per_token_logps = None + ref_input = None + if use_ref_model and use_ref_per_token_logps: + # Create reference log probs with shape [B, T] + ref_per_token_logps = per_token_logps.detach() + torch.randn(B, T, device=device) * 0.01 + elif use_ref_model: + # Create reference inputs (optional) with shape [B, T, H] if ref_log_probs is None + ref_input = _input.detach() + torch.randn(B, T, H, device=device, dtype=dtype) * 0.01 + + if old_per_token_logps: + old_per_token_logps = per_token_logps.detach() + torch.randn(B, T, device=device) * 0.01 + else: + old_per_token_logps = None + + # Forward pass with 
reference model + loss1, aux1 = torch_lm_head_grpo( + input1, + selected_token_ids, + attention_mask, + advantages, + ref_per_token_logps=ref_per_token_logps, + old_per_token_logps=old_per_token_logps, + ref_input=ref_input, + ) + loss2, aux2 = liger_lm_head_grpo( + input2, + selected_token_ids, + attention_mask, + advantages, + ref_per_token_logps=ref_per_token_logps, + old_per_token_logps=old_per_token_logps, + ref_input=ref_input, + ) + # Check losses match + assert not torch.isnan(loss1) + assert not torch.isnan(loss2) + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + # Check metrics match + assert len(aux1) == len(aux2) + # aggregated metrics are unstable for bfloat16 + for metric1, metric2 in zip(aux1, aux2): + assert_verbose_allclose(metric1, metric2, atol=atol, rtol=rtol) + + # Backward pass + loss1.backward() + loss2.backward() + + # Check gradients match for loss_type + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head_grpo.lin.weight.grad, + liger_lm_head_grpo.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + if bias: + assert_verbose_allclose( + torch_lm_head_grpo.lin.bias.grad, + liger_lm_head_grpo.lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize("loss_type", ["grpo", "dapo"]) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-5, 5e-4), + ], +) +def test_correctness_with_bias_correction_kl(loss_type, dtype, atol, rtol): + """Test use_bias_correction_kl (importance-sampling-corrected KL from DeepSeek-V3.2).""" + B, T, H, V = 3, 47, 31, 123 + beta = 0.1 # Must be non-zero for KL to matter + torch.compiler.reset() + + torch_lm_head_grpo = TorchLMHeadGRPO( + H=H, + V=V, + dtype=dtype, + beta=beta, + loss_type=loss_type, + use_bias_correction_kl=True, + ) + liger_lm_head_grpo = LigerLMHeadGRPO( + H=H, + V=V, + dtype=dtype, + beta=beta, + loss_type=loss_type, + use_bias_correction_kl=True, + ) + + torch_lm_head_grpo.lin.weight.data = liger_lm_head_grpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + torch_lm_head_grpo.ref_lin.weight.data = liger_lm_head_grpo.ref_lin.weight.data = ( + torch_lm_head_grpo.lin.weight.data + torch.randn(V, H, device=device, dtype=dtype) * 0.01 + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + selected_token_ids = torch.randint(0, V, (B, T), device=device) + attention_mask = torch.ones(B, T, device=device, dtype=dtype) + attention_mask[:, -10:] = 0 + advantages = torch.randn(B, device=device, dtype=torch.float32) + old_per_token_logps = torch.randn(B, T, device=device, dtype=torch.float32) + + loss1, metrics1 = torch_lm_head_grpo( + input1, + selected_token_ids, + attention_mask, + advantages, + old_per_token_logps=old_per_token_logps, + ref_input=input1.detach(), + ) + loss2, metrics2 = liger_lm_head_grpo( + input2, + selected_token_ids, + attention_mask, + advantages, + old_per_token_logps=old_per_token_logps, + ref_input=input2.detach(), + ) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + loss1.backward() + loss2.backward() + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head_grpo.lin.weight.grad, + liger_lm_head_grpo.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize("loss_type", ["bnpo", "grpo", "dapo", "cispo", "sapo", "luspo"]) +@pytest.mark.parametrize("beta", 
[0.0, 0.1]) +def test_correctness_with_vllm_is_ratio(loss_type, beta): + """Test vllm_is_ratio correctness against torch reference, and 1D/2D shape equivalence.""" + torch.compiler.reset() + B, T, H, V = 4, 32, 64, 128 + dtype = torch.float32 + atol, rtol = 1e-5, 5e-4 + + _weight = torch.randn(V, H, device=device, dtype=dtype) + _input = torch.randn(B, T, H, device=device, dtype=dtype) + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + selected_token_ids = torch.randint(0, V, (B, T), device=device) + attention_mask = torch.ones(B, T, device=device) + attention_mask[:, -5:] = 0 + advantages = torch.randn(B, device=device, dtype=dtype) + advantages[0] = -advantages[0].abs() # ensure mixed signs for SAPO + + vllm_is_ratio = torch.rand(B, T, device=device, dtype=torch.float32) * 0.999 + 0.001 + + torch_lm = TorchLMHeadGRPO(H=H, V=V, dtype=dtype, beta=beta, loss_type=loss_type, use_ref_model=False) + liger_lm = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, beta=beta, loss_type=loss_type, use_ref_model=False) + torch_lm.lin.weight.data = liger_lm.lin.weight.data = _weight.clone() + + loss1, aux1 = torch_lm(input1, selected_token_ids, attention_mask, advantages, vllm_is_ratio=vllm_is_ratio) + loss2, aux2 = liger_lm(input2, selected_token_ids, attention_mask, advantages, vllm_is_ratio=vllm_is_ratio) + + assert not torch.isnan(loss1) + assert not torch.isnan(loss2) + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + for m1, m2 in zip(aux1, aux2): + assert_verbose_allclose(m1, m2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose(torch_lm.lin.weight.grad, liger_lm.lin.weight.grad, atol=atol, rtol=rtol) + + # Verify 1D (B,) gives same result as (B, 1) + uniform_val = 0.42 + input3 = _input.detach().clone().requires_grad_(True) + input4 = _input.detach().clone().requires_grad_(True) + liger3 = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, beta=beta, loss_type=loss_type, use_ref_model=False) + liger4 = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, beta=beta, loss_type=loss_type, use_ref_model=False) + liger3.lin.weight.data = liger4.lin.weight.data = _weight.clone() + + loss3, _ = liger3( + input3, + selected_token_ids, + attention_mask, + advantages, + vllm_is_ratio=torch.full((B,), uniform_val, device=device), + ) + loss4, _ = liger4( + input4, + selected_token_ids, + attention_mask, + advantages, + vllm_is_ratio=torch.full((B, 1), uniform_val, device=device), + ) + assert_verbose_allclose(loss3, loss4, atol=1e-5, rtol=1e-5) + loss3.backward() + loss4.backward() + assert_verbose_allclose(input3.grad, input4.grad, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +def test_functional_correctness( + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, +): + # Reset torch compiler cache for each parameter of the test case + torch.compiler.reset() + max_completion_length = T + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + _weight = torch.randn(V, H, device=device, dtype=dtype) * scalar + weight1 = 
_weight.detach().clone().requires_grad_(True) + weight2 = _weight.detach().clone().requires_grad_(True) + + selected_token_ids = torch.randint(0, V, (B, T), device=device) + + attention_mask = torch.ones(B, T, device=device) + + advantages = torch.rand(B, device=device, dtype=dtype) + + if bias: + _bias = torch.randn(V, device=device, dtype=dtype) * scalar + bias1 = _bias.detach().clone().requires_grad_(True) + bias2 = _bias.detach().clone().requires_grad_(True) + else: + bias1 = None + bias2 = None + + ref_input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + + _ref_weight = _weight.detach() + torch.randn(V, H, device=device, dtype=dtype) * 0.01 + ref_weight1 = _ref_weight.detach().clone().requires_grad_(True) + ref_weight2 = _ref_weight.detach().clone().requires_grad_(True) + + if bias: + _ref_bias = _bias.detach() + torch.randn(V, device=device, dtype=dtype) * 0.01 + ref_bias1 = _ref_bias.detach().clone().requires_grad_(True) + ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) + else: + ref_bias1 = None + ref_bias2 = None + + old_per_token_logps = None + ref_per_token_logps = None + + loss1, aux1 = liger_fused_linear_grpo( + input1, + weight1, + selected_token_ids, + attention_mask, + advantages, + bias1, + ref_per_token_logps, + old_per_token_logps, + ref_input, + ref_weight1, + ref_bias1, + 0.04, + 0.2, + 0.2, + "bnpo", + max_completion_length, + "token", + 1.0, + False, + True, + 1, + ) + + loss2, aux2 = LigerFusedLinearGRPOFunction.apply( + input2, + weight2, + selected_token_ids, + attention_mask, + advantages, + bias2, + ref_per_token_logps, + old_per_token_logps, + ref_input, + ref_weight2, + ref_bias2, + 0.04, + 0.2, + 0.2, + "bnpo", + max_completion_length, + "token", + 1.0, + False, + True, + 1, + ) + + assert not torch.isnan(loss1) + assert not torch.isnan(loss2) + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + # Check metrics match + assert len(aux1) == len(aux2) + # aggregated metrics are unstable for bfloat16 + for metric1, metric2 in zip(aux1, aux2): + assert_verbose_allclose(metric1, metric2, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("loss_type", ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]) +def test_reduce_grpo_loss_matches_reference(loss_type): + torch.manual_seed(0) + per_token_loss = torch.randn(3, 5) + mask = torch.randint(0, 2, (3, 5), device=per_token_loss.device, dtype=torch.long) + mask[:, 0] = 1 # ensure at least one valid token per sequence + max_completion_length = 5 if loss_type == "dr_grpo" else None + + reduced = _reduce_grpo_loss(per_token_loss, mask, loss_type, max_completion_length) + + mask_f = mask.to(per_token_loss.dtype) + if loss_type == "grpo": + expected = ((per_token_loss * mask_f).sum(-1) / mask_f.sum(-1).clamp(min=1.0)).mean() + elif loss_type == "bnpo": + expected = (per_token_loss * mask_f).sum() / mask_f.sum().clamp(min=1.0) + elif loss_type == "dr_grpo": + expected = (per_token_loss * mask_f).sum() / (per_token_loss.size(0) * max_completion_length) + elif loss_type == "luspo": + expected = (per_token_loss * mask_f.sum(-1, keepdim=True)).mean() + else: # dapo/cispo + expected = (per_token_loss * mask_f).sum() / mask_f.sum().clamp(min=1.0) + + assert_verbose_allclose(reduced, expected) + + +def test_reduce_grpo_loss_requires_max_completion_length(): + per_token_loss = torch.randn(2, 3) + mask = torch.ones_like(per_token_loss, dtype=torch.long) + reduced = _reduce_grpo_loss(per_token_loss, mask, "dr_grpo", max_completion_length=None) + expected = (per_token_loss * mask).sum() / 
(per_token_loss.size(0) * per_token_loss.size(1)) + assert_verbose_allclose(reduced, expected) + + +@pytest.mark.parametrize("loss_type", ["cispo", "sapo"]) +def test_sequence_level_rejects_unsupported_loss_types(loss_type): + """Sequence-level importance sampling should raise ValueError for cispo and sapo.""" + B, T, H, V = 2, 8, 16, 32 + dtype = torch.float32 + + liger_lm = LigerLMHeadGRPO( + H=H, + V=V, + dtype=dtype, + beta=0.0, + loss_type=loss_type, + use_ref_model=False, + importance_sampling_level="sequence", + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype).requires_grad_(True) + selected_token_ids = torch.randint(0, V, (B, T), device=device) + attention_mask = torch.ones(B, T, device=device) + advantages = torch.randn(B, device=device) + + with pytest.raises(ValueError, match="Sequence-level importance sampling is not supported"): + liger_lm(_input, selected_token_ids, attention_mask, advantages) + + +@pytest.mark.parametrize("loss_type,beta", [("bnpo", 0.0), ("dapo", 0.04)]) +def test_triton_grpo_loss_matches_reference(loss_type, beta): + pytest.importorskip("triton") + device = infer_device() + + B, T, V = 2, 4, 16 + logits = torch.randn(B, T + 1, V, device=device, dtype=torch.float32).contiguous() + completion_ids = torch.randint(0, V, (B, T), device=device) + completion_mask = torch.randint(0, 2, (B, T), device=device, dtype=torch.long) + completion_mask[:, 0] = 1 # ensure each sequence has at least one valid token + advantages = torch.randn(B, device=device, dtype=torch.float32) + old_logp = torch.randn(B, T, device=device, dtype=torch.float32) + ref_logp = torch.randn(B, T, device=device, dtype=torch.float32) if beta != 0.0 else None + + per_token_loss, per_token_kl, is_clipped = triton_grpo_loss( + logits=logits, + old_logp=old_logp, + ref_logp=ref_logp, + completion_ids=completion_ids, + advantages=advantages, + completion_mask=completion_mask, + temperature=1.0, + beta=beta, + eps_low=0.2, + eps_high=0.2, + inplace=False, + loss_type=loss_type, + max_completion_length=T, + reduce=False, + ) + + logits_main = logits[:, :-1, :] + log_probs = torch.log_softmax(logits_main, dim=-1) + per_token_logps = log_probs.gather(dim=-1, index=completion_ids.unsqueeze(-1)).squeeze(-1) + ref_tokens = ref_logp if ref_logp is not None else per_token_logps.detach() + reference_loss, reference_kl, reference_is_clipped = TorchLMHeadGRPO.compute_per_token_components( + per_token_logps, + completion_mask.float(), + advantages, + old_logp, + ref_tokens, + 0.2, + 0.2, + beta, + "token", + ) + + mask = completion_mask.float() + mask_bool = mask.bool() + assert_verbose_allclose(per_token_loss, reference_loss * mask) + assert torch.equal(is_clipped.bool()[mask_bool], reference_is_clipped[mask_bool]) + if beta != 0.0: + assert_verbose_allclose(per_token_kl, reference_kl * mask) + else: + assert per_token_kl is None + + reduced_loss, metrics = triton_grpo_loss( + logits=logits, + old_logp=old_logp, + ref_logp=ref_logp, + completion_ids=completion_ids, + advantages=advantages, + completion_mask=completion_mask, + temperature=1.0, + beta=beta, + eps_low=0.2, + eps_high=0.2, + inplace=False, + loss_type=loss_type, + max_completion_length=T, + reduce=True, + ) + expected_loss = _reduce_grpo_loss(reference_loss, completion_mask, loss_type, T) + assert_verbose_allclose(reduced_loss, expected_loss) + if beta != 0.0: + assert_verbose_allclose(metrics[0], _masked_mean(reference_kl, completion_mask)) + clip_metric = metrics[1] + else: + clip_metric = metrics[0] + 
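As a reading aid for the reductions asserted above: the loss_type variants that `_reduce_grpo_loss` is checked against differ only in their normalizer. A minimal standalone sketch (plain PyTorch, independent of the liger_kernel helpers; the function name here is illustrative) mirroring the reference formulas from `test_reduce_grpo_loss_matches_reference`:

import torch


def reduce_grpo_loss_sketch(per_token_loss, mask, loss_type, max_completion_length=None):
    # mask marks valid completion tokens (1.0) vs. padding (0.0)
    mask = mask.to(per_token_loss.dtype)
    if loss_type == "grpo":
        # per-sequence token mean, then mean over the batch
        return ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()
    if loss_type in ("bnpo", "dapo"):
        # one global mean over every valid token in the batch
        return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
    if loss_type == "dr_grpo":
        # fixed normalizer B * max_completion_length, independent of actual lengths
        return (per_token_loss * mask).sum() / (per_token_loss.size(0) * max_completion_length)
    raise ValueError(f"unhandled loss_type: {loss_type}")


per_token_loss = torch.tensor([[1.0, 2.0, 3.0, 4.0], [2.0, 2.0, 2.0, 2.0]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]])
print(reduce_grpo_loss_sketch(per_token_loss, mask, "grpo"))  # (1.5 + 2.0) / 2 = 1.75
print(reduce_grpo_loss_sketch(per_token_loss, mask, "bnpo"))  # 11 / 6 = 1.8333...
print(reduce_grpo_loss_sketch(per_token_loss, mask, "dr_grpo", max_completion_length=4))  # 11 / (2 * 4) = 1.375

The variants agree when every sequence uses all max_completion_length tokens and disagree otherwise, which is why the tests above randomize the completion mask rather than using all-ones.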
assert_verbose_allclose(clip_metric, _masked_mean(reference_is_clipped.float(), completion_mask)) + + +def _reference_per_token_loss( + logits, + completion_ids, + completion_mask, + advantages, + old_logp, + ref_logp, + beta, + eps_low, + eps_high, + temperature=1.0, + delta=None, + use_bias_correction_kl=False, +): + logits = logits[:, :-1, :] / temperature + log_probs = torch.log_softmax(logits, dim=-1) + per_token_logps = log_probs.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1) + old = old_logp if old_logp is not None else per_token_logps.detach() + coef_1 = torch.exp(per_token_logps - old) + coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high) + if delta is not None: + coef_1 = torch.clamp(coef_1, max=delta) + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.minimum(per_token_loss1, per_token_loss2) + is_clipped = per_token_loss1 < per_token_loss2 + mask = completion_mask.to(torch.bool) + per_token_loss = per_token_loss.masked_fill(~mask, 0.0) + is_clipped = is_clipped & mask + if beta != 0.0: + kl = torch.exp(ref_logp - per_token_logps) - (ref_logp - per_token_logps) - 1.0 + if use_bias_correction_kl: + kl = kl * torch.exp(per_token_logps - old) + kl = kl.masked_fill(~mask, 0.0) + per_token_loss = per_token_loss + beta * kl + else: + kl = None + return { + "per_token_loss": per_token_loss, + "kl": kl, + "is_clipped": is_clipped, + } + + +def _masked_mean(values, mask): + mask = mask.to(values.dtype) + return (values * mask).sum() / mask.sum().clamp(min=1.0) diff --git a/test/chunked_loss/test_jsd_loss.py b/test/chunked_loss/test_jsd_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..415a015c921d8027fe54e214a06e24725968a970 --- /dev/null +++ b/test/chunked_loss/test_jsd_loss.py @@ -0,0 +1,441 @@ +import math + +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss +from liger_kernel.chunked_loss.functional import liger_fused_linear_jsd +from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction +from liger_kernel.utils import infer_device +from test.utils import HFDistillationLoss +from test.utils import assert_verbose_allclose +from test.utils import set_seed + +device = infer_device() + +# set random seed globally +set_seed() + + +class HFJSDLoss(HFDistillationLoss): + """ + Naive implementation of a distillation loss using Jensen-Shannon Divergence (JSD). + """ + + def __init__( + self, + temperature: float = 1.0, + ignore_index: int = -100, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + ): + super().__init__( + ignore_index=ignore_index, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + temperature=temperature, + ) + + def distillation_loss(self, student_logits, teacher_logits, target=None, ignore_index=-100, beta=0.5): + """ + Compute JSD loss (Jensen-Shannon Divergence Loss). + Args: + student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len, vocab_size). + teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size). + target (torch.Tensor): Target labels for masking. Shape: (batch_size * seq_len,). + ignore_index (int): Index to ignore in loss computation. + beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`. 
+ Returns: + torch.Tensor: Jensen-Shannon Divergence loss + """ + student_log_probs = F.log_softmax(student_logits, dim=-1) + teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) + + if beta == 0: + jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) + elif beta == 1: + jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) + else: + log_mean_probs = torch.logsumexp( + torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0 + ) + student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="none", log_target=True) + teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="none", log_target=True) + jsd_loss = beta * teacher_kl + (1 - beta) * student_kl + + # Sum over vocab dimension + jsd_loss = jsd_loss.sum(dim=-1) + + # Apply ignore_index mask + if target is not None: + mask = target != ignore_index + jsd_loss = jsd_loss * mask.float() + num_valid_tokens = mask.sum().clamp_min(1) + return jsd_loss.sum() / num_valid_tokens + + return jsd_loss.sum() + + +class TorchLMHeadJSD(torch.nn.Module): + """Ground truth implementation of the linear fused with torch based jsd loss. + :param H: hidden size + :param V: vocab size + :param temperature: softmax temperature + :param weight_hard_loss: weight_hard_loss + :param weight_soft_loss: weight_soft_loss + """ + + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool, + device: torch.device, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + ): + super().__init__() + # smaller student model weights + self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype, device=device) + self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype, device=device) + self.beta = beta + self.jsd = HFJSDLoss( + ignore_index=ignore_index, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + temperature=temperature, + ).get_batch_loss_metrics + + def forward(self, student_input, teacher_input, target): + jsd_loss = self.jsd( + student_input, + self.student_lin.weight, + teacher_input, + self.teacher_lin.weight, + target, + self.student_lin.bias, + self.teacher_lin.bias, + beta=self.beta, + ) + return jsd_loss + + def backward_with_grad_and_value(self, student_input, teacher_input, target): + """ + Compute gradients using grad_and_value on NPU to match Liger implementation. + This method is used in tests on NPU devices to ensure consistency. 
+ """ + # Use grad_and_value to compute gradients and loss + if self.student_lin.bias is not None: + + def loss_fn(student_input, student_weight, student_bias): + return self.jsd( + student_input, + student_weight, + teacher_input, + self.teacher_lin.weight, + target, + student_bias, + self.teacher_lin.bias, + beta=self.beta, + ) + + (grad_input, grad_weight, grad_bias), loss = torch.func.grad_and_value(loss_fn, argnums=(0, 1, 2))( + student_input, self.student_lin.weight, self.student_lin.bias + ) + + # Set gradients + student_input.grad = grad_input + self.student_lin.weight.grad = grad_weight + self.student_lin.bias.grad = grad_bias + else: + + def loss_fn(student_input, student_weight): + return self.jsd( + student_input, + student_weight, + teacher_input, + self.teacher_lin.weight, + target, + None, # student_bias is None when bias=False + self.teacher_lin.bias, + beta=self.beta, + ) + + (grad_input, grad_weight), loss = torch.func.grad_and_value(loss_fn, argnums=(0, 1))( + student_input, self.student_lin.weight + ) + + # Set gradients + student_input.grad = grad_input + self.student_lin.weight.grad = grad_weight + + return loss + + +class LigerLMHeadJSD(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool, + device: torch.device, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + ): + super().__init__() + # smaller student model weights + self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype, device=device) + self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype, device=device) + self.chunked_jsd = LigerFusedLinearJSDLoss( + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + ignore_index=ignore_index, + temperature=temperature, + beta=beta, + ) + + def forward(self, student_input, teacher_input, target): + return self.chunked_jsd( + student_input, + self.student_lin.weight, + teacher_input, + self.teacher_lin.weight, + target, + self.student_lin.bias, + self.teacher_lin.bias, + ) + + +############################################################################# +# Test the correctness of the fused linear JSD +############################################################################# + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize( + "temperature, weight_hard_loss, weight_soft_loss, beta", + [ + (1.0, 0.5, 0.5, 0.5), + (2.0, 0.0, 1.0, 0.8), + (0.5, 1.0, 0.0, 0.2), + ], +) +@pytest.mark.parametrize("ignore_index", [-100, 42]) +def test_correctness( + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + temperature, + weight_hard_loss, + weight_soft_loss, + beta, + ignore_index, +): + torch_lm_head_jsd = TorchLMHeadJSD( + H=H, + V=V, + dtype=dtype, + bias=bias, + device=device, + temperature=temperature, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + beta=beta, + ignore_index=ignore_index, + ) + liger_lm_head_jsd = LigerLMHeadJSD( + H=H, + V=V, + dtype=dtype, + bias=bias, + device=device, + temperature=temperature, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + beta=beta, + ignore_index=ignore_index, + 
) + + torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand( + V, H // 2, device=device, dtype=dtype + ) + torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_jsd.student_lin.bias.data = liger_lm_head_jsd.student_lin.bias.data = torch.rand( + V, device=device, dtype=dtype + ) + torch_lm_head_jsd.teacher_lin.bias.data = liger_lm_head_jsd.teacher_lin.bias.data = torch.rand( + V, device=device, dtype=dtype + ) + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + student_input1 = _tensor.detach().clone().requires_grad_(True) + student_input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target[indices_to_assign] = ignore_index + + # On NPU, use grad_and_value for the reference implementation to match the Liger implementation + if device == "npu": + loss1 = torch_lm_head_jsd.backward_with_grad_and_value(student_input1, teacher_input, target) + loss2 = liger_lm_head_jsd(student_input2, teacher_input, target) + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + loss2.backward() + else: + loss1 = torch_lm_head_jsd(student_input1, teacher_input, target) + loss2 = liger_lm_head_jsd(student_input2, teacher_input, target) + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + loss1.backward() + loss2.backward() + + assert_verbose_allclose(student_input1.grad, student_input2.grad, atol=atol, rtol=rtol) + + assert_verbose_allclose( + torch_lm_head_jsd.student_lin.weight.grad, + liger_lm_head_jsd.student_lin.weight.grad, + atol=atol, + rtol=rtol, + ) + + if bias: + assert_verbose_allclose( + torch_lm_head_jsd.student_lin.bias.grad, + liger_lm_head_jsd.student_lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + (9, 7, 41, 41), + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-2), + (1.0, torch.float32, 1e-4, 5e-3), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize( + "temperature, weight_hard_loss, weight_soft_loss, beta, ignore_index", + [(1.0, 0.5, 0.5, 0.5, -100), (2.0, 0.1, 0.9, 0.5, 42)], +) +def test_correctness_functional( + B, + T, + H, + V, + scalar, + dtype, + bias, + weight_hard_loss, + weight_soft_loss, + beta, + ignore_index, + temperature, + atol, + rtol, +): + _weight = torch.rand(V, H // 2, device=device, dtype=dtype) + student_weight1 = _weight.detach().clone().requires_grad_(True) + student_weight2 = _weight.detach().clone().requires_grad_(True) + teacher_weight = torch.rand(V, H, device=device, dtype=dtype) + + if bias: + _bias = torch.rand(V, device=device, dtype=dtype) + student_bias1 = _bias.detach().clone().requires_grad_(True) + student_bias2 = _bias.detach().clone().requires_grad_(True) + teacher_bias = torch.rand(V, device=device, dtype=dtype) + else: + student_bias1 = student_bias2 = teacher_bias = None + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + student_input1 = _tensor.detach().clone().requires_grad_(True) + student_input2 = 
_tensor.detach().clone().requires_grad_(True) + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + + label = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + output1 = liger_fused_linear_jsd( + student_input1, + student_weight1, + teacher_input, + teacher_weight, + label, + student_bias1, + teacher_bias, + weight_hard_loss, + weight_soft_loss, + beta, + ignore_index, + temperature, + ) + output2 = LigerFusedLinearJSDFunction.apply( + student_input2, + student_weight2, + teacher_input, + teacher_weight, + label, + student_bias2, + teacher_bias, + weight_hard_loss, + weight_soft_loss, + beta, + ignore_index, + temperature, + ) + + assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) + + output1.backward() + output2.backward() + + assert_verbose_allclose(student_input1.grad, student_input2.grad, atol=atol, rtol=rtol) + + assert_verbose_allclose(student_weight1.grad, student_weight2.grad, atol=atol, rtol=rtol) + + if bias: + assert_verbose_allclose(student_bias1.grad, student_bias2.grad, atol=atol, rtol=rtol) diff --git a/test/chunked_loss/test_kto_loss.py b/test/chunked_loss/test_kto_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..593159aa60d7b93f1093c9f94beca49496886fc8 --- /dev/null +++ b/test/chunked_loss/test_kto_loss.py @@ -0,0 +1,434 @@ +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss +from liger_kernel.chunked_loss.functional import liger_fused_linear_kto +from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction +from liger_kernel.utils import infer_device +from test.utils import HFAlignmentLoss +from test.utils import assert_verbose_allclose +from test.utils import set_seed + +device = infer_device() + +# set random seed globally +set_seed(0) + + +class HFKTOLoss(HFAlignmentLoss): + """ + Implementation of the Kahneman-Tversky Optimization (KTO) loss, + adapted from Hugging Face's implementation. + Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/kto_trainer.py + """ + + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + use_ref_model: bool = True, + ): + super().__init__( + beta=beta, + ignore_index=ignore_index, + use_ref_model=use_ref_model, + unpaired=True, + compute_nll_loss=False, + ) + + def alignment_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ref_chosen_logps: torch.FloatTensor, + ref_rejected_logps: torch.FloatTensor, + kl: torch.FloatTensor = None, + ): + """Compute KTO loss for a batch of policy log probabilities. + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + ref_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) + ref_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) + kl: Precomputed KL divergence between the policy and reference distributions, used as the reference point of the KTO loss. + Returns: + A tuple (losses, sum_chosen_rewards, sum_rejected_rewards): the losses tensor contains the KTO loss for each example in the batch, and the reward sums are aggregated over the batch for logging purposes. 
+ """ + if kl is None: + kl = torch.zeros(1).to(policy_chosen_logps.device) + + # Chosen losses + chosen_logratios = policy_chosen_logps - ref_chosen_logps + if policy_chosen_logps.shape[0] != 0 or ref_chosen_logps.shape[0] != 0: + # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306) + chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl)) + + else: + # lists can't be empty -- if they are, then accelerate.gather will hang + chosen_losses = torch.Tensor([]).to(policy_chosen_logps.device) + + # Rejected losses + rejected_logratios = policy_rejected_logps - ref_rejected_logps + if policy_rejected_logps.shape[0] != 0 or ref_rejected_logps.shape[0] != 0: + rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios)) + else: + # lists can't be empty -- if they are, then accelerate.gather will hang + rejected_losses = torch.Tensor([]).to(policy_rejected_logps.device) + + losses = torch.cat( + (chosen_losses, rejected_losses), + 0, + ) + + chosen_rewards = self.beta * chosen_logratios + rejected_rewards = self.beta * rejected_logratios + + return losses, chosen_rewards.sum(), rejected_rewards.sum() + + +class TorchLMHeadKTO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ref_bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype) + self.KTO_loss = HFKTOLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + ).get_batch_loss_metrics + + def forward(self, x, ref_x, y, preference_labels, kl=None): + return self.KTO_loss( + weight=self.lin.weight, + _input=x, + target=y, + bias=self.lin.bias, + ref_input=ref_x, + ref_weight=self.ref_lin.weight, + ref_bias=self.ref_lin.bias, + preference_labels=preference_labels, + kl=kl, + average_log_prob=True, + ) + + +class LigerLMHeadKTO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ref_bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype) + self.KTO_loss = LigerFusedLinearKTOLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + average_log_prob=True, + ) + + def forward(self, x, ref_x, y, preference_labels, kl=None): + return self.KTO_loss( + _input=x, + lin_weight=self.lin.weight, + target=y, + preference_labels=preference_labels, + bias=self.lin.bias, + ref_input=ref_x, + ref_weight=self.ref_lin.weight, + ref_bias=self.ref_lin.bias, + kl=kl, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("ref_bias", [True, False]) +@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)]) +def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, ignore_index, beta): + # Preference labels shape: [B] + # Create binary preference labels (0 or 1) for each sequence in the batch + # Used to indicate preferred sequences (1) vs non-preferred 
sequences (0) + preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device, requires_grad=False) + num_chosen_samples = preference_labels.sum() + num_rejected_samples = len(preference_labels) - num_chosen_samples + + # Precomputed KL divergence between policy and reference distributions + kl = torch.randn(1, device=device, dtype=dtype) + + torch_lm_head_KTO = TorchLMHeadKTO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ref_bias=ref_bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_KTO = LigerLMHeadKTO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ref_bias=ref_bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_KTO.lin.weight.data = liger_lm_head_KTO.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + torch_lm_head_KTO.ref_lin.weight.data = liger_lm_head_KTO.ref_lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_KTO.lin.bias.data = liger_lm_head_KTO.lin.bias.data = torch.randn(V, device=device, dtype=dtype) + if ref_bias: + torch_lm_head_KTO.ref_lin.bias.data = liger_lm_head_KTO.ref_lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_KTO( + x=input1, ref_x=ref_input, y=target, preference_labels=preference_labels, kl=kl + ) + loss2, aggregated_aux_outputs2 = liger_lm_head_KTO( + x=input2, ref_x=ref_input, y=target, preference_labels=preference_labels, kl=kl + ) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + # Metrics tests are flaky for bf16 due to precision issues + if dtype == torch.float32: + # chosen_logps + chosen_logps_mean1 = aggregated_aux_outputs1[0] / ((num_chosen_samples) + 1e-20) + chosen_logps_mean2 = aggregated_aux_outputs2[0] / ((num_chosen_samples) + 1e-20) + assert_verbose_allclose(chosen_logps_mean1, chosen_logps_mean2, atol=atol, rtol=rtol) + + # chosen_logits + chosen_logits_mean1 = aggregated_aux_outputs1[2] / ((num_chosen_samples * T * V) + 1e-20) + chosen_logits_mean2 = aggregated_aux_outputs2[2] / ((num_chosen_samples * T * V) + 1e-20) + assert_verbose_allclose(chosen_logits_mean1, chosen_logits_mean2, atol=atol, rtol=rtol) + + # chosen_rewards + chosen_rewards_mean1 = aggregated_aux_outputs1[4] / ((num_chosen_samples) + 1e-20) + chosen_rewards_mean2 = aggregated_aux_outputs2[4] / ((num_chosen_samples) + 1e-20) + assert_verbose_allclose(chosen_rewards_mean1, chosen_rewards_mean2, atol=atol, rtol=rtol) + + # rejected_logps + rejected_logps_mean1 = aggregated_aux_outputs1[1] / ((num_rejected_samples) + 1e-20) + rejected_logps_mean2 = aggregated_aux_outputs2[1] / ((num_rejected_samples) + 1e-20) + assert_verbose_allclose(rejected_logps_mean1, rejected_logps_mean2, atol=atol, rtol=rtol) + + # rejected_logits + rejected_logits_mean1 = aggregated_aux_outputs1[3] / ((num_rejected_samples * T * 
V) + 1e-20) + rejected_logits_mean2 = aggregated_aux_outputs2[3] / ((num_rejected_samples * T * V) + 1e-20) + assert_verbose_allclose(rejected_logits_mean1, rejected_logits_mean2, atol=atol, rtol=rtol) + + # rejected_rewards + rejected_rewards_mean1 = aggregated_aux_outputs1[5] / ((num_rejected_samples) + 1e-20) + rejected_rewards_mean2 = aggregated_aux_outputs2[5] / ((num_rejected_samples) + 1e-20) + assert_verbose_allclose(rejected_rewards_mean1, rejected_rewards_mean2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1, input2, atol=atol, rtol=rtol) + assert_verbose_allclose(torch_lm_head_KTO.lin.weight, liger_lm_head_KTO.lin.weight, atol=atol, rtol=rtol) + + if bias: + assert_verbose_allclose(torch_lm_head_KTO.lin.bias, liger_lm_head_KTO.lin.bias, atol=atol, rtol=rtol) + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head_KTO.lin.weight.grad, + liger_lm_head_KTO.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + if bias: + assert_verbose_allclose( + torch_lm_head_KTO.lin.bias.grad, + liger_lm_head_KTO.lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("ref_bias", [True, False]) +def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias): + # Preference labels shape: [B] + # Create binary preference labels (0 or 1) for each sequence in the batch + # Used to indicate preferred sequences (1) vs non-preferred sequences (0) + preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device) + num_chosen_samples = preference_labels.sum() + num_rejected_samples = len(preference_labels) - num_chosen_samples + + # Precomputed KL divergence between policy and reference distributions + kl = torch.randn(1, device=device, dtype=dtype) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + + _weight = torch.randn(V, H, device=device, dtype=dtype) + weight1 = _weight.detach().clone().requires_grad_(True) + weight2 = _weight.detach().clone().requires_grad_(True) + + _ref_weight = torch.randn(V, H, device=device, dtype=dtype) + ref_weight1 = _ref_weight.detach().clone().requires_grad_(True) + ref_weight2 = _ref_weight.detach().clone().requires_grad_(True) + + _bias = torch.randn(V, device=device, dtype=dtype) if bias else None + bias1 = _bias.detach().clone().requires_grad_(True) if bias else None + bias2 = _bias.detach().clone().requires_grad_(True) if bias else None + + _ref_bias = torch.randn(V, device=device, dtype=dtype) if ref_bias else None + ref_bias1 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None + ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None + + loss1, aggregated_aux_outputs1 = LigerFusedLinearKTOFunction.apply( + input1, + weight1, + target, + preference_labels, + bias1, + ref_input, + ref_weight1, + ref_bias1, + kl, + ) + loss2, aggregated_aux_outputs2 
= liger_fused_linear_kto( + input2, + weight2, + target, + preference_labels, + bias2, + ref_input, + ref_weight2, + ref_bias2, + kl, + ) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + # Metrics tests are flaky for bf16 due to precision issues + if dtype == torch.float32: + # chosen_logps + chosen_logps_mean1 = aggregated_aux_outputs1[0] / ((num_chosen_samples) + 1e-20) + chosen_logps_mean2 = aggregated_aux_outputs2[0] / ((num_chosen_samples) + 1e-20) + assert_verbose_allclose(chosen_logps_mean1, chosen_logps_mean2, atol=atol, rtol=rtol) + + # chosen_logits + chosen_logits_mean1 = aggregated_aux_outputs1[2] / ((num_chosen_samples * T * V) + 1e-20) + chosen_logits_mean2 = aggregated_aux_outputs2[2] / ((num_chosen_samples * T * V) + 1e-20) + assert_verbose_allclose(chosen_logits_mean1, chosen_logits_mean2, atol=atol, rtol=rtol) + + # chosen_rewards + chosen_rewards_mean1 = aggregated_aux_outputs1[4] / ((num_chosen_samples) + 1e-20) + chosen_rewards_mean2 = aggregated_aux_outputs2[4] / ((num_chosen_samples) + 1e-20) + assert_verbose_allclose(chosen_rewards_mean1, chosen_rewards_mean2, atol=atol, rtol=rtol) + + # rejected_logps + rejected_logps_mean1 = aggregated_aux_outputs1[1] / ((num_rejected_samples) + 1e-20) + rejected_logps_mean2 = aggregated_aux_outputs2[1] / ((num_rejected_samples) + 1e-20) + assert_verbose_allclose(rejected_logps_mean1, rejected_logps_mean2, atol=atol, rtol=rtol) + + # rejected_logits + rejected_logits_mean1 = aggregated_aux_outputs1[3] / ((num_rejected_samples * T * V) + 1e-20) + rejected_logits_mean2 = aggregated_aux_outputs2[3] / ((num_rejected_samples * T * V) + 1e-20) + assert_verbose_allclose(rejected_logits_mean1, rejected_logits_mean2, atol=atol, rtol=rtol) + + # rejected_rewards + rejected_rewards_mean1 = aggregated_aux_outputs1[5] / ((num_rejected_samples) + 1e-20) + rejected_rewards_mean2 = aggregated_aux_outputs2[5] / ((num_rejected_samples) + 1e-20) + assert_verbose_allclose(rejected_rewards_mean1, rejected_rewards_mean2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol) + if bias: + assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol) diff --git a/test/chunked_loss/test_orpo_loss.py b/test/chunked_loss/test_orpo_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..a5c43b1b0b7e0dd33ae77a83fd0b3586e9c81d7c --- /dev/null +++ b/test/chunked_loss/test_orpo_loss.py @@ -0,0 +1,266 @@ +from typing import Tuple + +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss +from liger_kernel.chunked_loss.functional import liger_fused_linear_orpo +from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction +from liger_kernel.utils import infer_device +from test.utils import HFAlignmentLoss +from test.utils import assert_verbose_allclose +from test.utils import set_seed + +device = infer_device() + +# set random seed globally +set_seed() + + +class HFORPOLoss(HFAlignmentLoss): + """ + Implementation of the Odds Ratio Preference Optimization (ORPO) loss, + adapted from Hugging Face's implementation. 
+ Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/orpo_trainer.py + """ + + def __init__(self, ignore_index: int = -100, beta: float = 0.1): + super().__init__(beta=beta, ignore_index=ignore_index) + + def alignment_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + ]: + """Compute ORPO's odds ratio (OR) loss for a batch of policy model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + A tuple of five tensors: (losses, chosen_rewards, rejected_rewards, mean_ratio, mean_log_odds). + The losses tensor contains the ORPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + mean_ratio is the batch mean of `log(sigmoid(log_odds))`, for logging purposes. + mean_log_odds is the batch mean of the log odds ratio of the chosen responses over the rejected responses, for logging purposes. + """ + + # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x))) = P(y|x) + log_odds = (policy_chosen_logps - policy_rejected_logps) - ( + torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps)) + ) + ratio = F.logsigmoid(log_odds) + losses = -self.beta * ratio + + chosen_rewards = self.beta * policy_chosen_logps + rejected_rewards = self.beta * policy_rejected_logps + + return ( + losses, + chosen_rewards, + rejected_rewards, + torch.mean(ratio), + torch.mean(log_odds), + ) + + +class TorchLMHeadORPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.orpo_loss = HFORPOLoss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics + + def forward(self, x, y, nll_target=None): + return self.orpo_loss(self.lin.weight, x, y, self.lin.bias, nll_target=nll_target) + + +class LigerLMHeadORPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.orpo_loss = LigerFusedLinearORPOLoss(ignore_index=ignore_index, beta=beta) + + def forward(self, x, y, nll_target=None): + return self.orpo_loss(self.lin.weight, x, y, self.lin.bias, nll_target=nll_target) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)]) +def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta): + # reset torch compiler cache + torch.compiler.reset() + + B = 2 * B # ORPO loss requires B to be even + torch_lm_head_orpo = TorchLMHeadORPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + 
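As a reading aid for the `alignment_loss` math above: the log-odds term is log[odds(P_chosen) / odds(P_rejected)] with odds(p) = p / (1 - p), evaluated entirely in log space (log(1 - P) becomes log1p(-exp(log P))) for numerical stability. A tiny standalone sketch with made-up per-sequence log probabilities (illustrative values only, not taken from the test):

import torch
import torch.nn.functional as F

beta = 0.1
# made-up per-sequence log probabilities; in the tests these come from the LM head
policy_chosen_logps = torch.tensor([-0.3])
policy_rejected_logps = torch.tensor([-1.2])

# log[odds(P_chosen) / odds(P_rejected)], where odds(p) = p / (1 - p)
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
    torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
)
losses = -beta * F.logsigmoid(log_odds)  # same expression as HFORPOLoss.alignment_loss
print(log_odds, losses)  # tensor([1.8918]) tensor([0.0140])

Widening the gap between the chosen and rejected log probabilities drives log_odds up and the loss toward zero; beta only scales the penalty.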
liger_lm_head_orpo = LigerLMHeadORPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_orpo.lin.weight.data = liger_lm_head_orpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_orpo.lin.bias.data = liger_lm_head_orpo.lin.bias.data = torch.randn(V, device=device, dtype=dtype) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + nll_target = torch.randint(0, V, (B, T), device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_orpo(input1, target, nll_target) + loss2, aggregated_aux_outputs2 = liger_lm_head_orpo(input2, target, nll_target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head_orpo.lin.weight.grad, + liger_lm_head_orpo.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + if bias: + assert_verbose_allclose( + torch_lm_head_orpo.lin.bias.grad, + liger_lm_head_orpo.lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): + # reset torch compiler cache + torch.compiler.reset() + + B = 2 * B + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + + _weight = torch.randn(V, H, device=device, dtype=dtype) + weight1 = _weight.detach().clone().requires_grad_(True) + weight2 = _weight.detach().clone().requires_grad_(True) + + _bias = torch.randn(V, device=device, dtype=dtype) if bias else None + bias1 = _bias.detach().clone().requires_grad_(True) if bias else None + bias2 = _bias.detach().clone().requires_grad_(True) if bias else None + + loss1, _ = LigerFusedLinearORPOFunction.apply(input1, weight1, target, bias1) + loss2, _ = liger_fused_linear_orpo(input2, weight2, target, bias2) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol) + if bias: + assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol) diff --git a/test/chunked_loss/test_simpo_loss.py b/test/chunked_loss/test_simpo_loss.py 
new file mode 100755 index 0000000000000000000000000000000000000000..4a6f019598a9dcfd0d78d060d2e3a8196aaa0a9e --- /dev/null +++ b/test/chunked_loss/test_simpo_loss.py @@ -0,0 +1,215 @@ +import pytest +import torch + +from liger_kernel.chunked_loss import LigerFusedLinearSimPOLoss +from liger_kernel.chunked_loss.functional import liger_fused_linear_simpo +from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction +from liger_kernel.utils import infer_device +from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO +from test.utils import assert_verbose_allclose +from test.utils import set_seed + +device = infer_device() + +# set random seed globally +set_seed() + + +class LigerLMHeadSimPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + alpha: float = 1.0, + label_smoothing: float = 0.0, + gamma: float = 0.5, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.simpo_loss = LigerFusedLinearSimPOLoss( + ignore_index=ignore_index, + beta=beta, + alpha=alpha, + gamma=gamma, + label_smoothing=label_smoothing, + ) + + def forward(self, x, y): + return self.simpo_loss(self.lin.weight, x, y, self.lin.bias) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("ignore_index, beta, gamma", [(-100, 0.1, 0.5), (42, 0.2, 0.85)]) +@pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) +def test_correctness( + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + ignore_index, + beta, + gamma, + label_smoothing, +): + B = 2 * B # SimPO loss requires B to be even + + torch_lm_head_simpo = TorchLMHeadCPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + loss_type="simpo", + label_smoothing=label_smoothing, + simpo_gamma=gamma, + ) + liger_lm_head_simpo = LigerLMHeadSimPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ignore_index=ignore_index, + beta=beta, + label_smoothing=label_smoothing, + gamma=gamma, + ) + + torch_lm_head_simpo.lin.weight.data = liger_lm_head_simpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_simpo.lin.bias.data = liger_lm_head_simpo.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_simpo(input1, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_simpo(input2, target) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + assert_verbose_allclose( + aggregated_aux_outputs1[i], + 
aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head_simpo.lin.weight.grad, + liger_lm_head_simpo.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + if bias: + assert_verbose_allclose( + torch_lm_head_simpo.lin.bias.grad, + liger_lm_head_simpo.lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): + B = 2 * B + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + + _weight = torch.randn(V, H, device=device, dtype=dtype) + weight1 = _weight.detach().clone().requires_grad_(True) + weight2 = _weight.detach().clone().requires_grad_(True) + + _bias = torch.randn(V, device=device, dtype=dtype) if bias else None + bias1 = _bias.detach().clone().requires_grad_(True) if bias else None + bias2 = _bias.detach().clone().requires_grad_(True) if bias else None + + loss1, aggregated_aux_outputs1 = LigerFusedLinearSimPOFunction.apply(input1, weight1, target, bias1) + loss2, aggregated_aux_outputs2 = liger_fused_linear_simpo(input2, weight2, target, bias2) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol) + if bias: + assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol) diff --git a/test/conftest.py b/test/conftest.py new file mode 100755 index 0000000000000000000000000000000000000000..3d36a0d2256f2ac058d191394a62fcc1668c9f28 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,15 @@ +import pytest +import torch + +from liger_kernel.utils import is_npu_available + + +@pytest.fixture(autouse=True) +def clear_gpu_cache(): + yield + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif is_npu_available(): + torch.npu.empty_cache() + elif torch.xpu.is_available(): + torch.xpu.empty_cache() diff --git a/test/convergence/__init__.py b/test/convergence/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/convergence/bf16/__init__.py b/test/convergence/bf16/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/convergence/bf16/test_mini_models.py b/test/convergence/bf16/test_mini_models.py new file mode 100755 index 0000000000000000000000000000000000000000..98400b87041a7fb988ea79fef35d354d50247cff --- /dev/null +++ b/test/convergence/bf16/test_mini_models.py @@ -0,0 +1,2324 @@ +import os + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Ensure deterministic behavior with CuBLAS + +import pytest +import torch +import transformers + +from datasets import load_from_disk +from packaging import version +from torch.utils.data import DataLoader +from transformers.models.gemma 
import GemmaConfig +from transformers.models.gemma import GemmaForCausalLM +from transformers.models.gemma2 import Gemma2Config +from transformers.models.gemma2 import Gemma2ForCausalLM +from transformers.models.llama import LlamaConfig +from transformers.models.llama import LlamaForCausalLM +from transformers.models.mistral import MistralConfig +from transformers.models.mistral import MistralForCausalLM +from transformers.models.mixtral import MixtralConfig +from transformers.models.mixtral import MixtralForCausalLM +from transformers.models.phi3 import Phi3Config +from transformers.models.phi3 import Phi3ForCausalLM +from transformers.models.qwen2 import Qwen2Config +from transformers.models.qwen2 import Qwen2ForCausalLM + +from liger_kernel.transformers import apply_liger_kernel_to_exaone4 +from liger_kernel.transformers import apply_liger_kernel_to_falcon_h1 +from liger_kernel.transformers import apply_liger_kernel_to_gemma +from liger_kernel.transformers import apply_liger_kernel_to_gemma2 +from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text +from liger_kernel.transformers import apply_liger_kernel_to_glm4 +from liger_kernel.transformers import apply_liger_kernel_to_glm4v +from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe +from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss +from liger_kernel.transformers import apply_liger_kernel_to_granite +from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_dense +from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_moe +from liger_kernel.transformers import apply_liger_kernel_to_internvl +from liger_kernel.transformers import apply_liger_kernel_to_llama +from liger_kernel.transformers import apply_liger_kernel_to_llama4 +from liger_kernel.transformers import apply_liger_kernel_to_llava +from liger_kernel.transformers import apply_liger_kernel_to_mistral +from liger_kernel.transformers import apply_liger_kernel_to_mixtral +from liger_kernel.transformers import apply_liger_kernel_to_mllama +from liger_kernel.transformers import apply_liger_kernel_to_olmo2 +from liger_kernel.transformers import apply_liger_kernel_to_olmo3 +from liger_kernel.transformers import apply_liger_kernel_to_phi3 +from liger_kernel.transformers import apply_liger_kernel_to_qwen2 +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen3 +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5_moe +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl_moe +from liger_kernel.transformers import apply_liger_kernel_to_smollm3 +from liger_kernel.utils import infer_device +from test.utils import DEFAULT_DATASET_PATH +from test.utils import MiniModelConfig +from test.utils import assert_verbose_allclose +from test.utils import get_logprobs +from test.utils import get_topk +from test.utils import require_deterministic +from test.utils import revert_liger_kernel_to_exaone4 +from test.utils import revert_liger_kernel_to_falcon_h1 +from test.utils import revert_liger_kernel_to_gemma +from test.utils import revert_liger_kernel_to_gemma2 +from 
test.utils import revert_liger_kernel_to_gemma3_text +from test.utils import revert_liger_kernel_to_glm4 +from test.utils import revert_liger_kernel_to_glm4v +from test.utils import revert_liger_kernel_to_glm4v_moe +from test.utils import revert_liger_kernel_to_gpt_oss +from test.utils import revert_liger_kernel_to_granite +from test.utils import revert_liger_kernel_to_hunyuan_v1 +from test.utils import revert_liger_kernel_to_hunyuan_v1_moe +from test.utils import revert_liger_kernel_to_internvl +from test.utils import revert_liger_kernel_to_llama +from test.utils import revert_liger_kernel_to_llama4 +from test.utils import revert_liger_kernel_to_llava +from test.utils import revert_liger_kernel_to_mistral +from test.utils import revert_liger_kernel_to_mixtral +from test.utils import revert_liger_kernel_to_mllama +from test.utils import revert_liger_kernel_to_olmo2 +from test.utils import revert_liger_kernel_to_olmo3 +from test.utils import revert_liger_kernel_to_phi3 +from test.utils import revert_liger_kernel_to_qwen2 +from test.utils import revert_liger_kernel_to_qwen2_5_vl +from test.utils import revert_liger_kernel_to_qwen2_vl +from test.utils import revert_liger_kernel_to_qwen3 +from test.utils import revert_liger_kernel_to_qwen3_5 +from test.utils import revert_liger_kernel_to_qwen3_5_moe +from test.utils import revert_liger_kernel_to_qwen3_moe +from test.utils import revert_liger_kernel_to_qwen3_next +from test.utils import revert_liger_kernel_to_qwen3_vl +from test.utils import revert_liger_kernel_to_qwen3_vl_moe +from test.utils import revert_liger_kernel_to_smollm3 +from test.utils import set_seed +from test.utils import simple_collate_fn +from test.utils import supports_bfloat16 + +IS_TRANSFORMERS_V5_OR_LATER = version.parse(transformers.__version__) >= version.parse("5.0.0") + +try: + from transformers.models.llama4.configuration_llama4 import Llama4TextConfig + from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM + + LLAMA4_AVAILABLE = True +except ImportError: + LLAMA4_AVAILABLE = False + +try: + # Mllama is only available in transformers>=4.45.0 + from transformers.models.mllama.configuration_mllama import MllamaTextConfig + from transformers.models.mllama.modeling_mllama import MllamaForCausalLM + + MLLAMA_AVAILABLE = True +except ImportError: + MLLAMA_AVAILABLE = False + +try: + # Qwen2-VL is only available in transformers>4.52.4 + import transformers + + from packaging import version + from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration + + QWEN2_VL_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.52.4") +except ImportError: + QWEN2_VL_AVAILABLE = False + +try: + # Qwen2.5-VL is only available in transformers>4.52.4 + import transformers + + from packaging import version + from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration + + QWEN2_5_VL_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.52.4") +except ImportError: + QWEN2_5_VL_AVAILABLE = False + + +try: + # Qwen2.5-VL is only available in transformers>=4.57.0 + import transformers + + from packaging import version + from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration + + QWEN3_VL_AVAILABLE = 
version.parse(transformers.__version__) >= version.parse("4.57.0") +except ImportError: + QWEN3_VL_AVAILABLE = False + + +try: + # Qwen3-VL-MoE is only available in transformers>=4.57.0 + import transformers + + from packaging import version + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeTextConfig + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeVisionConfig + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration + + QWEN3_VL_MOE_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.57.0") +except ImportError: + QWEN3_VL_MOE_AVAILABLE = False + + +try: + from transformers.models.qwen3.configuration_qwen3 import Qwen3Config + from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM + from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig + from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeForCausalLM + + QWEN3_AVAILABLE = True +except ImportError: + QWEN3_AVAILABLE = False + +try: + from transformers.models.granite import GraniteConfig + from transformers.models.granite import GraniteForCausalLM + + GRANITE_AVAILABLE = True +except ImportError: + GRANITE_AVAILABLE = False + +try: + from transformers import CLIPVisionConfig + from transformers.models.llava.configuration_llava import LlavaConfig + from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration + + LLAVA_AVAILABLE = True +except ImportError: + LLAVA_AVAILABLE = False + +try: + # OLMO2 is only available in transformers>=4.47.0 + from transformers.models.olmo2.configuration_olmo2 import Olmo2Config + from transformers.models.olmo2.modeling_olmo2 import Olmo2ForCausalLM + + OLMO2_AVAILABLE = True +except ImportError: + OLMO2_AVAILABLE = False + +try: + # OLMO3 is only available in transformers>=4.57.0 + from transformers.models.olmo3.configuration_olmo3 import Olmo3Config + from transformers.models.olmo3.modeling_olmo3 import Olmo3ForCausalLM + + OLMO3_AVAILABLE = True +except ImportError: + OLMO3_AVAILABLE = False + +try: + # Glm4 is only available in transformers>=4.51.3 + from transformers.models.glm4.configuration_glm4 import Glm4Config + from transformers.models.glm4.modeling_glm4 import Glm4ForCausalLM + + GLM4_AVAILABLE = True +except ImportError: + GLM4_AVAILABLE = False + +try: + # Glm4v is only available in transformers>=4.51.3 + from transformers.models.glm4v.configuration_glm4v import Glm4vConfig + from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration + + GLM4V_AVAILABLE = True +except ImportError: + GLM4V_AVAILABLE = False + +try: + # Glm4v_moe is only available in transformers>=4.51.3 + from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig + from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration + + GLM4V_MOE_AVAILABLE = True +except ImportError: + GLM4V_MOE_AVAILABLE = False + +try: + from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig + from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM + + GEMMA3_AVAILABLE = True +except ImportError: + GEMMA3_AVAILABLE = False + +try: + # Smollm3 is only available in transformers>=4.53.0 + from transformers.models.smollm3.configuration_smollm3 import SmolLM3Config + from transformers.models.smollm3.modeling_smollm3 import 
SmolLM3ForCausalLM + + SMOLLM3_AVAILABLE = True +except ImportError: + SMOLLM3_AVAILABLE = False + +try: + # InternVL is only available in transformers>=4.52.1 + from transformers.models.internvl.configuration_internvl import InternVLConfig + from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration + + INTERNVL_AVAILABLE = True +except ImportError: + INTERNVL_AVAILABLE = False + +try: + # FalconH1 is only available in transformers>=4.53.0 + from transformers.models.falcon_h1.configuration_falcon_h1 import FalconH1Config + from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1ForCausalLM + + FALCONH1_AVAILABLE = True +except ImportError: + FALCONH1_AVAILABLE = False + +try: + # GPT-OSS is only available in transformers>=4.55.0 + from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig + from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM + + GPT_OSS_AVAILABLE = True +except ImportError: + GPT_OSS_AVAILABLE = False + +try: + # Qwen3Next is only available in transformers>=4.57.0 + from transformers.models.qwen3_next.configuration_qwen3_next import Qwen3NextConfig + from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM + + QWEN3NEXT_AVAILABLE = True +except ImportError: + QWEN3NEXT_AVAILABLE = False + +try: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeTextConfig + + QWEN3_5_MOE_AVAILABLE = True +except ImportError: + QWEN3_5_MOE_AVAILABLE = False + +try: + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM + + QWEN3_5_AVAILABLE = True +except ImportError: + QWEN3_5_AVAILABLE = False + +try: + from transformers.models.hunyuan_v1_dense.configuration_hunyuan_v1_dense import HunYuanDenseV1Config + from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1ForCausalLM + from transformers.models.hunyuan_v1_moe.configuration_hunyuan_v1_moe import HunYuanMoEV1Config + from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1ForCausalLM + + HUNYUAN_V1_AVAILABLE = True +except ImportError: + HUNYUAN_V1_AVAILABLE = False + +try: + from transformers.models.exaone4.configuration_exaone4 import Exaone4Config + from transformers.models.exaone4.modeling_exaone4 import Exaone4ForCausalLM + + EXAONE4_AVAILABLE = True +except ImportError: + EXAONE4_AVAILABLE = False + + +device = infer_device() + +MINI_MODEL_SETUPS = { + "mini_llama3": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llama, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llama, + model_class=LlamaForCausalLM, + mini_model_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=8192, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + pretraining_tp=1, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA 
produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ), + "mini_qwen2": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2, + model_class=Qwen2ForCausalLM, + mini_model_config=Qwen2Config( + attention_dropout=0.0, + bos_token_id=1, # 151643 + eos_token_id=2, # 151643 + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, # 131072 + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-6, + sliding_window=131072, + tie_word_embeddings=True, + use_cache=True, + vocab_size=32000, # 151936 + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ), + "mini_phi3": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_phi3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_phi3, + model_class=Phi3ForCausalLM, + mini_model_config=Phi3Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, # 32000 + hidden_act="silu", + hidden_size=896, # 3072 + initializer_range=0.02, + intermediate_size=4864, # 8192 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=None, # defaults to num_attention_heads + rms_norm_eps=1e-5, + sliding_window=None, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32064, + attn_implementation="eager", + ), + ), + "mini_mistral": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mistral, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mistral, + model_class=MistralForCausalLM, + mini_model_config=MistralConfig( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-5, + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ), + "mini_mixtral": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mixtral, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mixtral, + model_class=MixtralForCausalLM, + mini_model_config=MixtralConfig( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=512, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=32768, # 32768 + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ), + "mini_gemma1": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma, + model_class=GemmaForCausalLM, + mini_model_config=GemmaConfig( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + # gemma1 model config uses `hidden_act` and point it to gelu, + # 
https://huggingface.co/google/gemma-7b/blob/main/config.json#L10 + # but in reality it's ignored and HuggingFace will use tanh approximation: + # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175 + hidden_act="gelu", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + ), + ), + "mini_gemma1.1": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma, + model_class=GemmaForCausalLM, + mini_model_config=GemmaConfig( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + ), + ), + "mini_gemma2": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma2, + model_class=Gemma2ForCausalLM, + mini_model_config=Gemma2Config( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + attn_implementation="eager", + ), + ), +} + +if LLAMA4_AVAILABLE: + MINI_MODEL_SETUPS["mini_llama4"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llama4, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llama4, + model_class=Llama4ForCausalLM, + mini_model_config=Llama4TextConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + partial_rotary_factor=1.0, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + + +if QWEN3_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3, + 
liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3, + model_class=Qwen3ForCausalLM, + mini_model_config=Qwen3Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-6, + sliding_window=131072, + tie_word_embeddings=True, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ) + + MINI_MODEL_SETUPS["mini_qwen3_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_moe, + model_class=Qwen3MoeForCausalLM, + mini_model_config=Qwen3MoeConfig( + vocab_size=32000, # 151936 + hidden_size=896, + intermediate_size=4864, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + attention_bias=False, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + decoder_sparse_step=1, + moe_intermediate_size=768, + num_experts_per_tok=2, + num_experts=8, + norm_topk_prob=False, + output_router_logits=False, + router_aux_loss_coef=0.001, + mlp_only_layers=None, + ), + ) + +if GPT_OSS_AVAILABLE: + MINI_MODEL_SETUPS["mini_gpt_oss"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gpt_oss, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gpt_oss, + model_class=GptOssForCausalLM, + mini_model_config=GptOssConfig( + vocab_size=32000, # 201088 + hidden_size=896, + intermediate_size=896, # Same as hidden_size for GPT-OSS + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + head_dim=64, + hidden_act="silu", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + attention_dropout=0.0, + num_local_experts=8, # Reduced from 32 for mini model + num_experts_per_tok=2, # Reduced from 4 for mini model + router_aux_loss_coef=0.9, + output_router_logits=False, + sliding_window=128, + layer_types=["sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(4)], + ), + ) + +if GEMMA3_AVAILABLE: + MINI_MODEL_SETUPS["mini_gemma3_text"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma3_text, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma3_text, + model_class=Gemma3ForCausalLM, + mini_model_config=Gemma3TextConfig( + vocab_size=32000, # 262144 + hidden_size=1024, # 1152 + intermediate_size=2048, # 6912 + num_hidden_layers=4, # 26 + num_attention_heads=4, + num_key_value_heads=1, + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, # 32768 + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + bos_token_id=2, + eos_token_id=1, + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + attn_implementation="eager", + ), + ) + + +if MLLAMA_AVAILABLE: + MINI_MODEL_SETUPS["mini_mllama"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mllama, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mllama, + model_class=MllamaForCausalLM, + mini_model_config=MllamaTextConfig( + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + 
hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=131_072, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + attn_implementation="sdpa", # default value, pytorch native attention + rope_scaling=dict( + factor=8.0, + high_freq_factor=4.0, + low_freq_factor=1.0, + original_max_position_embeddings=8192, + rope_type="llama3", + rope_theta=500_000, + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + ), + ) + +if QWEN2_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2_vl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_vl, + model_class=Qwen2VLForConditionalGeneration, + mini_model_config=Qwen2VLConfig( + # In transformers v5, text-related parameters must be in text_config + text_config={ + "attention_dropout": 0.0, + # bos and eos set to match the Mistral-7B tokenizer used to create the test dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + "bos_token_id": 1, # 151643 + "eos_token_id": 2, # 151645 + "hidden_act": "silu", + "hidden_size": 1536, # 8192 + "initializer_range": 0.02, + "intermediate_size": 4864, # 29568 + "max_position_embeddings": 32768, + "max_window_layers": 4, # 80 + "num_attention_heads": 12, # 64 + "num_hidden_layers": 4, # 80 + "num_key_value_heads": 2, # 8 + "rms_norm_eps": 1e-6, # 1e-5 + **( + {"rope_parameters": {"mrope_section": [16, 24, 24]}} # (temporal, height, width) + if IS_TRANSFORMERS_V5_OR_LATER + else {"rope_scaling": {"type": "mrope", "mrope_section": [16, 24, 24]}} + ), + "sliding_window": 4096, + "tie_word_embeddings": False, + "use_cache": True, + "vocab_size": 32768, # 152064 # >32k, Mistral-7B tokenizer vocab size + "use_sliding_window": False, + }, + vision_start_token_id=32765, # vocab_size - 5 + vision_end_token_id=32766, # vocab_size - 4 + image_token_id=32768, # vocab_size - 2 + video_token_id=32769, # vocab_size - 1 + vision_config={ + "depth": 4, # 32 + "embed_dim": 1280, + "mlp_ratio": 4, + "num_heads": 16, + "in_chans": 3, + "hidden_size": 128, # 1536 + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "temporal_patch_size": 2, + }, + ), + ) + +if QWEN2_5_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_5_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2_5_vl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_5_vl, + model_class=Qwen2_5_VLForConditionalGeneration, + mini_model_config=Qwen2_5_VLConfig( + # In transformers v5, text-related parameters must be in text_config + text_config={ + "attention_dropout": 0.0, + # bos and eos set to match the Mistral-7B tokenizer used to create the test dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + "bos_token_id": 1, # 151643 + "eos_token_id": 2, # 151645 + "hidden_act": "silu", + "hidden_size": 1536, # 8192 + "initializer_range": 0.02, + "intermediate_size": 4864, # 29568 + "max_position_embeddings": 32768, + "max_window_layers": 4, # 80 + "num_attention_heads": 12, # 64 + "num_hidden_layers": 4, # 80 + "num_key_value_heads": 2, # 8 + "rms_norm_eps": 1e-6, # 1e-5 + **( + {"rope_parameters": {"mrope_section": [16, 24, 24]}} # (temporal, height, width) + if IS_TRANSFORMERS_V5_OR_LATER + else {"rope_scaling": {"type": "mrope", "mrope_section": [16, 24, 24]}} + ), + "sliding_window": 4096, + 
"tie_word_embeddings": False, + "use_cache": True, + "vocab_size": 32768, # 152064 # >32k, Mistral-7B tokenizer vocab size + "use_sliding_window": False, + }, + vision_start_token_id=32765, # vocab_size - 5 + vision_end_token_id=32766, # vocab_size - 4 + image_token_id=32768, # vocab_size - 2 + video_token_id=32769, # vocab_size - 1 + vision_config={ + "depth": 4, # 32 + "hidden_act": "silu", + "hidden_size": 128, # 1280 + "intermediate_size": 256, # 3420 + "num_heads": 16, + "in_chans": 3, + "out_hidden_size": 128, # 3584 + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "window_size": 112, + "fullatt_block_indexes": [7, 15, 23, 31], + "tokens_per_second": 2, + "temporal_patch_size": 2, + }, + ), + ) + + +if QWEN3_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_vl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_vl, + model_class=Qwen3VLForConditionalGeneration, + mini_model_config=Qwen3VLConfig( + bos_token_id=1, + eos_token_id=2, + vision_start_token_id=32765, + vision_end_token_id=32766, + image_token_id=32768, + video_token_id=32769, + tie_word_embeddings=False, + attn_implementation="sdpa", + text_config=dict( + attention_dropout=0.0, + hidden_act="silu", + hidden_size=1536, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=12, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-6, + use_cache=True, + vocab_size=32768, + pad_token_id=None, + rope_scaling=dict( + type="mrope", + mrope_section=[16, 24, 24], # (temporal, height, width) + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + ), + vision_config=dict( + depth=4, + hidden_size=128, + hidden_act="silu", + intermediate_size=256, + num_heads=8, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=128, + num_position_embeddings=256, + deepstack_visual_indexes=[], + initializer_range=0.02, + ), + ), + ) + +if QWEN3_VL_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_vl_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_vl_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_vl_moe, + model_class=Qwen3VLMoeForConditionalGeneration, + mini_model_config=Qwen3VLMoeConfig( + bos_token_id=1, + eos_token_id=2, + vision_start_token_id=32765, + vision_end_token_id=32766, + image_token_id=32768, + video_token_id=32769, + tie_word_embeddings=False, + attn_implementation="sdpa", + text_config=Qwen3VLMoeTextConfig( + attention_dropout=0.0, + attention_bias=False, + hidden_act="silu", + hidden_size=1536, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=12, + num_hidden_layers=4, + num_key_value_heads=2, + head_dim=128, + rms_norm_eps=1e-6, + use_cache=True, + vocab_size=32768, + decoder_sparse_step=1, + moe_intermediate_size=3072, + num_experts_per_tok=2, + num_experts=4, + tie_word_embeddings=False, + mlp_only_layers=[], + pad_token_id=None, + rope_scaling=dict( + type="mrope", + mrope_section=[16, 24, 24], # (temporal, height, width) + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + ).to_dict(), + vision_config=Qwen3VLMoeVisionConfig( + depth=4, + hidden_size=128, + hidden_act="gelu_pytorch_tanh", + intermediate_size=256, + num_heads=8, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=128, + num_position_embeddings=256, + deepstack_visual_indexes=[1, 2, 
3], + initializer_range=0.02, + ).to_dict(), + ), + ) + +if GRANITE_AVAILABLE: + MINI_MODEL_SETUPS["mini_granite3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_granite, + liger_kernel_patch_revert_func=revert_liger_kernel_to_granite, + model_class=GraniteForCausalLM, + mini_model_config=GraniteConfig( + attention_bias=False, + attention_dropout=0.1, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=8192, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + pretraining_tp=1, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if LLAVA_AVAILABLE: + # https://huggingface.co/llava-hf/llava-1.5-7b-hf + MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llava, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llava, + model_class=LlavaForConditionalGeneration, + mini_model_config=LlavaConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pretraining_tp=1, + tie_word_embeddings=False, + use_cache=True, + max_position_embeddings=4096, # llava-1.5-7b-hf + rms_norm_eps=1e-05, # llava-1.5-7b-hf + vocab_size=32064, # llava-1.5-7b-hf + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + vision_config=CLIPVisionConfig( + hidden_size=1024, + image_size=336, + intermediate_size=2048, # 4096 + model_type="clip_vision_model", + num_attention_heads=4, # 16 + num_hidden_layers=4, # 24 + patch_size=14, + projection_dim=768, + vocab_size=32000, + ), + vocab_size=32064, + ignore_index=-100, + pad_token_id=4, + image_token_index=3, + projector_hidden_act="gelu", + vision_feature_layer=-2, + vision_feature_select_strategy="default", + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if OLMO2_AVAILABLE: + MINI_MODEL_SETUPS["mini_olmo2"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_olmo2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_olmo2, + model_class=Olmo2ForCausalLM, + mini_model_config=Olmo2Config( + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + 
tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if OLMO3_AVAILABLE: + MINI_MODEL_SETUPS["mini_olmo3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_olmo3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_olmo3, + model_class=Olmo3ForCausalLM, + mini_model_config=Olmo3Config( + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if GLM4_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4, + model_class=Glm4ForCausalLM, + mini_model_config=Glm4Config( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + partial_rotary_factor=0.5, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if GLM4V_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4v"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4v, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4v, + model_class=Glm4vForConditionalGeneration, + mini_model_config=Glm4vConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + image_token_id=151343, + video_token_id=151344, + image_start_token_id=151339, + image_end_token_id=151340, + video_start_token_id=151341, + video_end_token_id=151342, + partial_rotary_factor=0.5, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + text_config={ + "partial_rotary_factor": 0.5, + "hidden_act": "silu", + "hidden_size": 1024, + "intermediate_size": 2048, + "max_position_embeddings": 4096, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-5, + "vocab_size": 32000, + "attention_bias": True, + **( + {"rope_scaling": {"type": "default", "mrope_section": [8, 12, 12]}} + if not IS_TRANSFORMERS_V5_OR_LATER + else {} + ), + "pad_token_id": None, + }, + vision_config={ + "depth": 4, # 32 + "hidden_act": "silu", + "hidden_size": 128, # 1280 + "intermediate_size": 256, # 3420 + "num_heads": 16, + "in_chans": 3, + "out_hidden_size": 128, # 3584 + "patch_size": 14, + "spatial_merge_size": 2, + 
"temporal_patch_size": 2, + }, + ), + ) + +if GLM4V_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4v_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4v_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4v_moe, + model_class=Glm4vMoeForConditionalGeneration, + mini_model_config=Glm4vMoeConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + image_token_id=151343, + video_token_id=151344, + image_start_token_id=151339, + image_end_token_id=151340, + video_start_token_id=151341, + video_end_token_id=151342, + partial_rotary_factor=0.5, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + text_config={ + "partial_rotary_factor": 0.5, + "hidden_act": "silu", + "hidden_size": 1024, + "intermediate_size": 2048, + "max_position_embeddings": 4096, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-5, + "vocab_size": 32000, + "attention_bias": True, + "attention_dropout": 0.0, + "moe_intermediate_size": 1408, + "num_experts_per_tok": 2, + "n_shared_experts": 1, + "n_routed_experts": 8, + "routed_scaling_factor": 1.0, + "n_group": 1, + "topk_group": 1, + "first_k_dense_replace": 1, + "norm_topk_prob": True, + **( + {"rope_scaling": {"type": "default", "mrope_section": [8, 12, 12]}} + if not IS_TRANSFORMERS_V5_OR_LATER + else {} + ), + }, + vision_config={ + "depth": 4, # 32 + "hidden_act": "silu", + "hidden_size": 128, # 1280 + "intermediate_size": 256, # 3420 + "num_heads": 16, + "in_chans": 3, + "out_hidden_size": 128, # 3584 + "patch_size": 14, + "spatial_merge_size": 2, + "temporal_patch_size": 2, + }, + ), + ) + +if SMOLLM3_AVAILABLE: + MINI_MODEL_SETUPS["mini_smollm3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_smollm3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_smollm3, + model_class=SmolLM3ForCausalLM, + mini_model_config=SmolLM3Config( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, # 128000 + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=8192, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + pretraining_tp=1, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if INTERNVL_AVAILABLE: + MINI_MODEL_SETUPS["mini_internvl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_internvl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_internvl, + model_class=InternVLForConditionalGeneration, + mini_model_config=InternVLConfig( + text_config=Qwen2Config( + rms_norm_eps=1e-5, + hidden_size=256, # 1024 + intermediate_size=1024, # 4096 + hidden_act="silu", + num_hidden_layers=4, # 24 + num_attention_heads=4, # 16 + 
num_key_value_heads=2,  # 16
+                max_position_embeddings=4096,  # 8192
+                vocab_size=32000,  # 151936
+                bos_token_id=1,
+                eos_token_id=2,
+                pad_token_id=2,
+                tie_word_embeddings=False,
+            ),
+            vision_config={
+                "hidden_size": 256,  # 1024
+                "intermediate_size": 1024,  # 4096
+                "num_hidden_layers": 4,  # 24
+                "num_attention_heads": 4,  # 16
+            },
+            image_token_id=10,
+            attn_implementation="sdpa",  # default value, pytorch native attention
+        ),
+    )
+
+if FALCONH1_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_falcon_h1"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_falcon_h1,
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_falcon_h1,
+        model_class=FalconH1ForCausalLM,
+        mini_model_config=FalconH1Config(
+            model_type="falcon_h1",
+            vocab_size=32000,
+            hidden_size=256,  # 4096
+            num_hidden_layers=4,  # 24
+            num_attention_heads=4,  # 32
+            num_key_value_heads=2,  # 8
+            intermediate_size=1024,  # 11008
+            hidden_act="silu",
+            max_position_embeddings=4096,
+            initializer_range=0.02,
+            rms_norm_eps=1e-6,
+            use_cache=True,
+            pad_token_id=0,
+            bos_token_id=1,
+            eos_token_id=2,
+            tie_word_embeddings=False,
+            mamba_d_ssm=128,  # 1024
+            mamba_n_heads=16,  # 128
+            mamba_d_state=32,  # 245
+            mamba_d_conv=2,  # 4
+            attn_implementation="eager",
+        ),
+    )
+
+if QWEN3NEXT_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_qwen3_next"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_qwen3_next,
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_next,
+        model_class=Qwen3NextForCausalLM,
+        mini_model_config=Qwen3NextConfig(  # copied from Qwen3MoeConfig
+            vocab_size=32000,
+            hidden_size=896,
+            intermediate_size=4864,
+            num_hidden_layers=4,
+            num_attention_heads=8,
+            num_key_value_heads=2,
+            hidden_act="silu",
+            max_position_embeddings=32768,
+            initializer_range=0.02,
+            rms_norm_eps=1e-6,
+            use_cache=True,
+            tie_word_embeddings=False,
+            attention_bias=False,
+            use_sliding_window=False,
+            sliding_window=4096,
+            max_window_layers=28,
+            attention_dropout=0.0,
+            decoder_sparse_step=1,
+            moe_intermediate_size=768,
+            num_experts_per_tok=2,
+            num_experts=8,
+            norm_topk_prob=False,
+            output_router_logits=False,
+            router_aux_loss_coef=0.001,
+            # config.dtype must be set if fla is installed, since the original code has a bug
+            # (it calls torch.get_current_dtype(), which does not exist)
+            dtype=torch.bfloat16,
+        ),
+    )
+
+if QWEN3_5_MOE_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_qwen3_5_moe"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_qwen3_5_moe,
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_5_moe,
+        model_class=Qwen3_5MoeForCausalLM,
+        mini_model_config=Qwen3_5MoeTextConfig(
+            vocab_size=32000,
+            hidden_size=896,
+            num_hidden_layers=4,
+            num_attention_heads=8,
+            num_key_value_heads=2,
+            hidden_act="silu",
+            max_position_embeddings=32768,
+            initializer_range=0.02,
+            rms_norm_eps=1e-6,
+            use_cache=True,
+            tie_word_embeddings=False,
+            attention_bias=False,
+            attention_dropout=0.0,
+            head_dim=128,
+            linear_conv_kernel_dim=4,
+            linear_key_head_dim=64,
+            linear_value_head_dim=64,
+            linear_num_key_heads=8,
+            linear_num_value_heads=8,
+            moe_intermediate_size=768,
+            shared_expert_intermediate_size=768,
+            num_experts_per_tok=2,
+            num_experts=8,
+            output_router_logits=False,
+            router_aux_loss_coef=0.001,
+            # config.dtype must be set if fla is installed, since the original code has a bug
+            # (it calls torch.get_current_dtype(), which does not exist)
+            dtype=torch.bfloat16,
+        ),
+    )
+
+if QWEN3_5_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_qwen3_5"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_qwen3_5,
+
liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_5, + model_class=Qwen3_5ForCausalLM, + mini_model_config=Qwen3_5TextConfig( + vocab_size=32000, + hidden_size=896, + intermediate_size=4864, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + attention_bias=False, + attention_dropout=0.0, + head_dim=128, + linear_conv_kernel_dim=4, + linear_key_head_dim=64, + linear_value_head_dim=64, + linear_num_key_heads=8, + linear_num_value_heads=8, + layer_types=["linear_attention", "linear_attention", "linear_attention", "full_attention"], + dtype=torch.bfloat16, + ), + ) + +if HUNYUAN_V1_AVAILABLE: + MINI_MODEL_SETUPS["mini_hunyuan_v1"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_hunyuan_v1_dense, + liger_kernel_patch_revert_func=revert_liger_kernel_to_hunyuan_v1, + model_class=HunYuanDenseV1ForCausalLM, + mini_model_config=HunYuanDenseV1Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + num_hidden_layers=4, + hidden_size=896, + intermediate_size=4864, + num_attention_heads=8, + head_dim=112, + rms_norm_eps=1e-6, + tie_word_embeddings=True, + max_position_embeddings=32768, + initializer_range=0.02, + norm_eps=1e-6, + num_key_value_heads=2, + partial_rotary_factor=1.0, + vocab_size=32000, + use_cache=True, + attn_implementation="sdpa", + ), + ) + + MINI_MODEL_SETUPS["mini_hunyuan_v1_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_hunyuan_v1_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_hunyuan_v1_moe, + model_class=HunYuanMoEV1ForCausalLM, + mini_model_config=HunYuanMoEV1Config( + vocab_size=32000, + hidden_size=128, + intermediate_size=512, + head_dim=16, + num_hidden_layers=2, + num_attention_heads=8, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + eod_token_id=3, + sep_token_id=4, + tie_word_embeddings=False, + attention_bias=False, + attention_dropout=0.0, + num_experts=2, + moe_topk=1, + attn_implementation="sdpa", + ), + ) + +if EXAONE4_AVAILABLE: + MINI_MODEL_SETUPS["mini_exaone4"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_exaone4, + liger_kernel_patch_revert_func=revert_liger_kernel_to_exaone4, + model_class=Exaone4ForCausalLM, + mini_model_config=Exaone4Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=True, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + pad_token_id=None, + ), + ) + + +def create_model(model_name="mini_llama4"): + """ + Create a mini version model + The commented values are the original values + """ + model_config = MINI_MODEL_SETUPS[model_name].mini_model_config + model_class = MINI_MODEL_SETUPS[model_name].model_class + return model_class(model_config) + + +@require_deterministic +def run_mini_model( + model_name="mini_llama4", + num_steps=100, + dtype=torch.bfloat16, + lr=1e-5, + with_liger=False, +): + # If we move it to the beginning of test_mini_model, the two runs are initialized with different weights. 
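+    # A minimal, hypothetical illustration of the seeding behavior explained below
+    # (kept as a comment so the test flow is unchanged): re-seeding immediately before
+    # each construction makes two independently created modules start from identical
+    # weights, whereas dropping the second set_seed would not.
+    #
+    #     set_seed(42)
+    #     a = torch.nn.Linear(8, 8)
+    #     set_seed(42)
+    #     b = torch.nn.Linear(8, 8)
+    #     assert torch.equal(a.weight, b.weight)  # same RNG state -> same init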
+    # This is due to the RNG (random number generator): a linear congruential generator advances as x_(n+1) = (a * x_n + c) % m.
+    # Every time the RNG is used, e.g. when randomly initializing weights, it advances to the next state.
+    # Therefore, we have to reset the RNG before creating each model to ensure weight initialization starts from the same RNG state.
+
+    set_seed(42)
+
+    revert_kwargs = {"model_config": MINI_MODEL_SETUPS[model_name]}
+    if "mllama" in model_name:
+        revert_kwargs["model_type"] = "causal_lm"
+
+    if with_liger:
+        kwargs = {
+            "rope": True,
+            "rms_norm": True,
+        }
+
+        if "glm4" in model_name or "qwen3_next" in model_name or "qwen3_5" in model_name:
+            kwargs["rope"] = False
+
+        model_supports_layer_norm = "qwen2_vl" in model_name
+        if model_supports_layer_norm:
+            kwargs["layer_norm"] = True
+
+        if "gemma" in model_name:
+            kwargs["geglu"] = True
+        else:
+            kwargs["swiglu"] = True
+
+        if "llava" in model_name:
+            apply_liger_kernel_to_llama(**kwargs)
+
+        # fused_linear_cross_entropy is not supported in mini_granite3
+        kwargs["fused_linear_cross_entropy"] = model_name != "mini_granite3"
+        kwargs["cross_entropy"] = False
+
+        MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
+    else:
+        MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)
+
+    model = create_model(model_name).to(dtype).to(device)
+
+    train_dataset = load_from_disk(DEFAULT_DATASET_PATH)
+    loader = DataLoader(train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn)
+    loader_iter = iter(loader)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+
+    loss_list = []
+
+    for i in range(num_steps):
+        batch = next(loader_iter).to(model.device)
+        optimizer.zero_grad()
+        output = model(**batch, accum_dtype=torch.float32)
+        output.loss.backward()
+        optimizer.step()
+        print(f"Step {i}, Loss: {output.loss.item()}")
+        loss_list.append(output.loss.item())
+
+    model.eval()
+    eval_batch = next(loader_iter).to(model.device)
+    if with_liger:
+        # Materialize logits in the eval pass so top-k logprobs can be compared
+        eval_batch["skip_logits"] = False
+    with torch.no_grad():
+        eval_output = model(**eval_batch)
+    print(f"Eval Loss: {eval_output.loss.item()}")
+    loss_list.append(eval_output.loss.item())
+    topk_logprobs = get_topk(get_logprobs(eval_output.logits))
+    MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)
+    return {
+        "loss": loss_list,
+        "topk_logprobs": topk_logprobs.values,
+        "model": model,
+    }
+
+
+@pytest.mark.parametrize(
+    "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
+    [
+        pytest.param(
+            "mini_llama4",  # llama4 requires slightly larger tolerances to pass this test after the llama4 bug fix in transformers v5.0.0
+            32,
+            1e-5,
+            torch.bfloat16,
+            5e-2,
+            4e-1,
+            1e-1,
+            1e-1,
+            1e-2,
+            1e-2,
+            marks=[
+                pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+                pytest.mark.skipif(
+                    not LLAMA4_AVAILABLE,
+                    reason="Llama4 not available in this version of transformers",
+                ),
+                pytest.mark.skipif(
+                    not IS_TRANSFORMERS_V5_OR_LATER,
+                    reason="The `attention_bias` configuration of Llama4 is not set in Transformers v4",
+                ),
+            ],
+        ),
+        pytest.param(
+            "mini_llama3",
+            32,
+            1e-5,
+            torch.bfloat16,
+            1e-2,
+            5e-2,
+            1e-1,
+            1e-2,
+            1e-2,
+            1e-2,
+            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+        ),
+        pytest.param(
+            "mini_llava",
+            32,
+            1e-5,
+            torch.bfloat16,
+            1e-2,
+            5e-2,
+            1e-1,
+            1e-1,
+            1e-2,
+            1e-2,
+            marks=[
+                pytest.mark.skipif(not
supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + pytest.mark.skipif( + version.parse(transformers.__version__) < version.parse("4.52.0"), + reason="LLaVa doesn't materialize logits in transformers<=4.52.0 so we can't test it", + ), + ], + ), + pytest.param( + "mini_granite3", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, # 1e-1 + 1e-2, # 1e-2 + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GRANITE_AVAILABLE, + reason="Granite not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_mllama", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not MLLAMA_AVAILABLE, + reason="Mllama not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen2", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), + pytest.param( + "mini_qwen3", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3_AVAILABLE, + reason="Qwen3 not available in this version of transformers", + ), + ], + ), + # TODO(tcc): Investigate qwen3_moe on different machines. + # The loss diverges on ci test (A10G), but it never diverges on my local machine (3080). + # Qwen3_moe can pass float32 tests. 
(mecoli1219): diverges on h100 + pytest.param( + "mini_qwen3_moe", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 2e-1, + 1e-1, # 1e-1 + 1e-1, # 1e-2 + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3_AVAILABLE, + reason="Qwen3 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_gpt_oss", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-1, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GPT_OSS_AVAILABLE, + reason="GPT-OSS not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen2_vl", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, # 1e-1 + 1e-2, # 1e-2 + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen2_5_vl", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, # 1e-1 + 1e-2, # 1e-2 + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN2_5_VL_AVAILABLE, + reason="Qwen2.5-VL not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen3_vl", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, # 1e-1 + 1e-2, # 1e-2 + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3_VL_AVAILABLE, + reason="Qwen3-VL not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen3_vl_moe", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 1e-1, + 1e-1, + 5e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3_VL_MOE_AVAILABLE, + reason="Qwen3-VL-MoE not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_phi3", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), + pytest.param( + "mini_mistral", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ], + ), + pytest.param( + "mini_olmo2", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not OLMO2_AVAILABLE, + reason="OLMO2 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_olmo3", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not OLMO3_AVAILABLE, + reason="OLMO3 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_glm4", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GLM4_AVAILABLE, + 
reason="Glm4 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_glm4v", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 2e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GLM4V_AVAILABLE, + reason="Glm4v not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_glm4v_moe", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 4e-1, # rms_norm patch needs higher tolerance in bf16 + 1e-1, + 5e-1, # rms_norm patch needs higher tolerance in bf16 + 2e-1, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GLM4V_MOE_AVAILABLE, + reason="Glm4v_moe not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_smollm3", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not SMOLLM3_AVAILABLE, + reason="Smollm3 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_internvl", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not INTERNVL_AVAILABLE, + reason="InternVL not available in this version of transformers", + ), + ], + ), + # TODO: mixtral is flaky so disable the test for now + # pytest.param( + # "mini_mixtral", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-1, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way) + pytest.param( + "mini_gemma1", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), + pytest.param( + "mini_gemma1.1", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), + # TODO: Gemma2 test for bf16 is not passing within the tolerance range, might be casting issue, need to investigate + # pytest.param( + # "mini_gemma2", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + pytest.param( + "mini_gemma3_text", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GEMMA3_AVAILABLE, + reason="Gemma3 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_falcon_h1", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not FALCONH1_AVAILABLE, + reason="FalconH1 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen3_next", + 32, + 1e-5, + torch.bfloat16, + 
1e-2, + 1e-2, + 1e-1, + 1e-1, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3NEXT_AVAILABLE, + reason="Qwen3Next not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen3_5_moe", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 2e-1, + 1e-1, + 1e-1, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3_5_MOE_AVAILABLE, + reason="Qwen3_5Moe not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen3_5", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 2e-1, + 1e-1, + 1e-1, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3_5_AVAILABLE, + reason="Qwen3_5 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_hunyuan_v1", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not HUNYUAN_V1_AVAILABLE, + reason="Hunyuan_v1 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_hunyuan_v1_moe", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, # 1e-1 + 1e-1, # 1e-2 + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not HUNYUAN_V1_AVAILABLE, + reason="Hunyuan_v1_moe not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_exaone4", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not EXAONE4_AVAILABLE, + reason="EXAONE4 not available in this version of transformers", + ), + ], + ), + ], +) +def test_mini_model( + model_name, + num_steps, + lr, + dtype, + loss_atol, + loss_rtol, + logprobs_atol, + logprobs_rtol, + param_atol, + param_rtol, +): + # Non-liger models should be initialized and tested first to avoid the module being overridden + + expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr) + + actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True) + + # Compare every step of the loss + assert_verbose_allclose( + torch.tensor([expected_output["loss"]]), + torch.tensor([actual_output["loss"]]), + atol=loss_atol, + rtol=loss_rtol, + extra_info="[Loss]", + ) + + # Compare the topk logprobs from evaluation step + if expected_output["topk_logprobs"] is not None and actual_output["topk_logprobs"] is not None: + assert_verbose_allclose( + expected_output["topk_logprobs"], + actual_output["topk_logprobs"], + atol=logprobs_atol, + rtol=logprobs_rtol, + extra_info="[Top k logprobs]", + ) + + # Compare the params from the last step + # Iterate over the model's parameters and compare them + for expected_param, actual_param in zip( + expected_output["model"].named_parameters(), + actual_output["model"].named_parameters(), + ): + assert_verbose_allclose( + expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol, extra_info="[Model parameters]" + ) diff --git a/test/convergence/bf16/test_mini_models_multimodal.py 
b/test/convergence/bf16/test_mini_models_multimodal.py new file mode 100755 index 0000000000000000000000000000000000000000..5495c8bbbbc51b4fdaa06d179ed1e9e3b8c46ed0 --- /dev/null +++ b/test/convergence/bf16/test_mini_models_multimodal.py @@ -0,0 +1,1830 @@ +import functools +import os + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Ensure deterministic behavior with CuBLAS +import pytest +import torch +import transformers + +from datasets import load_dataset +from packaging import version +from torch.utils.data import DataLoader +from transformers import PreTrainedTokenizerFast +from transformers.models.siglip.configuration_siglip import SiglipVisionConfig + +from liger_kernel.transformers import apply_liger_kernel_to_gemma3 +from liger_kernel.transformers import apply_liger_kernel_to_internvl +from liger_kernel.transformers import apply_liger_kernel_to_llama4 +from liger_kernel.transformers import apply_liger_kernel_to_llava +from liger_kernel.transformers import apply_liger_kernel_to_mllama +from liger_kernel.transformers import apply_liger_kernel_to_paligemma +from liger_kernel.transformers import apply_liger_kernel_to_pixtral +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl_moe +from liger_kernel.transformers import apply_liger_kernel_to_smolvlm +from liger_kernel.utils import infer_device +from test.utils import FAKE_CONFIGS_PATH +from test.utils import UNTOKENIZED_DATASET_PATH +from test.utils import MiniModelConfig +from test.utils import assert_verbose_allclose +from test.utils import get_logprobs +from test.utils import get_topk +from test.utils import is_torchvision_available +from test.utils import load_image_processing_config +from test.utils import load_processor_config +from test.utils import load_tokenizer_config +from test.utils import multimodal_collate_fn +from test.utils import require_deterministic +from test.utils import revert_liger_kernel_to_gemma3 +from test.utils import revert_liger_kernel_to_internvl +from test.utils import revert_liger_kernel_to_llama4 +from test.utils import revert_liger_kernel_to_llava +from test.utils import revert_liger_kernel_to_mllama +from test.utils import revert_liger_kernel_to_Paligemma +from test.utils import revert_liger_kernel_to_pixtral +from test.utils import revert_liger_kernel_to_qwen2_5_vl +from test.utils import revert_liger_kernel_to_qwen2_vl +from test.utils import revert_liger_kernel_to_qwen3_5 +from test.utils import revert_liger_kernel_to_qwen3_vl +from test.utils import revert_liger_kernel_to_qwen3_vl_moe +from test.utils import revert_liger_kernel_to_smolvlm2 +from test.utils import set_seed +from test.utils import supports_bfloat16 +from test.utils import train_bpe_tokenizer + +IS_TRANSFORMERS_V5_OR_LATER = version.parse(transformers.__version__) >= version.parse("5.0.0") + +if IS_TRANSFORMERS_V5_OR_LATER: + from transformers.models.gemma.tokenization_gemma import GemmaTokenizer +else: + from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast as GemmaTokenizer + +try: + # Qwen2-VL is only available in transformers>=4.52.4 + import transformers + + from packaging import version + + if IS_TRANSFORMERS_V5_OR_LATER: + from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer + else: 
+ from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast as Qwen2Tokenizer + from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig + from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration + from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor + from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor + + QWEN2_VL_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.52.4") +except ImportError: + QWEN2_VL_AVAILABLE = False + +try: + # Qwen2.5-VL is only available in transformers>4.52.4 + import transformers + + from packaging import version + + if IS_TRANSFORMERS_V5_OR_LATER: + from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer + else: + from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast as Qwen2Tokenizer + from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration + from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessor + from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor + from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor + + QWEN2_5_VL_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.52.4") +except ImportError: + QWEN2_5_VL_AVAILABLE = False + +try: + if IS_TRANSFORMERS_V5_OR_LATER: + from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer + else: + from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast as Qwen2Tokenizer + from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor + from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig + from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLTextConfig + from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLVisionConfig + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration + from transformers.models.qwen3_vl.processing_qwen3_vl import Qwen3VLProcessor + from transformers.models.qwen3_vl.video_processing_qwen3_vl import Qwen3VLVideoProcessor + + QWEN3_VL_AVAILABLE = True +except ImportError: + QWEN3_VL_AVAILABLE = False + +try: + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeTextConfig + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeVisionConfig + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration + + QWEN3_VL_MOE_AVAILABLE = True +except ImportError: + QWEN3_VL_MOE_AVAILABLE = False + +try: + from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5VisionConfig + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForConditionalGeneration + from transformers.models.qwen3_vl.processing_qwen3_vl import Qwen3VLProcessor + from 
transformers.models.qwen3_vl.video_processing_qwen3_vl import Qwen3VLVideoProcessor + + QWEN3_5_AVAILABLE = True +except ImportError: + QWEN3_5_AVAILABLE = False + +try: + # Mllama is only available in transformers>=4.45.0 + from transformers.models.mllama.configuration_mllama import MllamaConfig + from transformers.models.mllama.configuration_mllama import MllamaTextConfig + from transformers.models.mllama.configuration_mllama import MllamaVisionConfig + from transformers.models.mllama.image_processing_mllama import MllamaImageProcessor + from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration + from transformers.models.mllama.processing_mllama import MllamaProcessor + + MLLAMA_AVAILABLE = True +except ImportError: + MLLAMA_AVAILABLE = False + +try: + from transformers import CLIPImageProcessor + from transformers import CLIPVisionConfig + from transformers import LlamaConfig + from transformers.models.llava.configuration_llava import LlavaConfig + from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration + from transformers.models.llava.processing_llava import LlavaProcessor + + from liger_kernel.transformers import apply_liger_kernel_to_llama + + LLAVA_AVAILABLE = True +except ImportError: + LLAVA_AVAILABLE = False + +try: + import transformers + + from packaging import version + from transformers.models.gemma.configuration_gemma import GemmaConfig + from transformers.models.gemma2.configuration_gemma2 import Gemma2Config + from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig + from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration + from transformers.models.paligemma.processing_paligemma import PaliGemmaProcessor + from transformers.models.siglip.configuration_siglip import SiglipVisionConfig + from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor + + PALIGEMMA_AVAILABLE = True +except ImportError: + PALIGEMMA_AVAILABLE = False + + +try: + # Gemma3 is only available in transformers>=4.50.0 + from transformers.models.gemma3.configuration_gemma3 import Gemma3Config + from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig + from transformers.models.gemma3.image_processing_gemma3 import Gemma3ImageProcessor + from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration + from transformers.models.gemma3.processing_gemma3 import Gemma3Processor + + GEMMA3_AVAILABLE = True +except ImportError: + GEMMA3_AVAILABLE = False + +try: + from transformers.models.llama4.configuration_llama4 import Llama4Config + from transformers.models.llama4.configuration_llama4 import Llama4TextConfig + from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig + from transformers.models.llama4.image_processing_llama4_fast import Llama4ImageProcessorFast + from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration + from transformers.models.llama4.processing_llama4 import Llama4Processor + + LLAMA4_AVAILABLE = True + +except ImportError: + LLAMA4_AVAILABLE = False + +try: + from transformers.models.got_ocr2.image_processing_got_ocr2_fast import GotOcr2ImageProcessorFast + from transformers.models.internvl.configuration_internvl import InternVLConfig + from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration + from transformers.models.internvl.processing_internvl import InternVLProcessor + from 
transformers.models.internvl.video_processing_internvl import InternVLVideoProcessor
+    from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+
+    # Passing fp32 inputs to bf16 CNN-based models in InternVL only works in transformers>=4.56.0
+    INTERNVL_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.56.0")
+except ImportError:
+    INTERNVL_AVAILABLE = False
+
+try:
+    # SmolVLM2 is only available in transformers>=4.50.0
+    from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+    from transformers.models.smolvlm.configuration_smolvlm import SmolVLMConfig
+    from transformers.models.smolvlm.image_processing_smolvlm import SmolVLMImageProcessor
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
+    from transformers.models.smolvlm.processing_smolvlm import SmolVLMProcessor
+    from transformers.models.smolvlm.video_processing_smolvlm import SmolVLMVideoProcessor
+
+    SMOLVLM2_AVAILABLE = True
+except ImportError:
+    SMOLVLM2_AVAILABLE = False
+
+try:
+    from transformers.models.pixtral.configuration_pixtral import PixtralVisionConfig
+    from transformers.models.pixtral.modeling_pixtral import PixtralVisionModel
+
+    PIXTRAL_AVAILABLE = True
+except ImportError:
+    PIXTRAL_AVAILABLE = False
+
+try:
+    from num2words import num2words  # noqa: F401
+
+    NUM2WORDS_AVAILABLE = True
+except ImportError:
+    NUM2WORDS_AVAILABLE = False
+
+
+device = infer_device()
+
+torch.use_deterministic_algorithms(True)
+
+# Only setting torch.use_deterministic_algorithms(True) throws the following error:
+# RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`,
+# but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an
+# environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information,
+# go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+
+TEST_IMAGE_DIM = 64
+
+MINI_MODEL_SETUPS = {}
+
+if LLAMA4_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_llama4"] = MiniModelConfig(
+        liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_llama4, fused_linear_cross_entropy=False),
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_llama4,
+        model_class=Llama4ForConditionalGeneration,
+        mini_model_config=Llama4Config(
+            image_token_index=8,
+            vision_config=Llama4VisionConfig(
+                attn_implementation_autoset=True,
+                attention_dropout=0.0,
+                hidden_act="gelu",
+                hidden_size=512,  # 1280
+                image_size=560,  # 560
+                initializer_range=0.02,
+                intermediate_layers_indices=[2],  # [3, 7, 15, etc...]
+                intermediate_size=2048,  # 5120
+                max_num_tiles=1,  # 4
+                norm_eps=1e-5,
+                num_attention_heads=4,  # 16
+                num_channels=3,
+                num_global_layers=2,  # 8
+                num_hidden_layers=8,  # 32
+                patch_size=280,  # 14
+                supported_aspect_ratios=[[1, 1]],  # [[1, 1], [1, 2], etc... ]
+                vision_output_dim=4096,  # 7680
+            ),
+            text_config=Llama4TextConfig(
+                bos_token_id=0,
+                eos_token_id=0,
+                pad_token_id=0,
+                cross_attention_layers=[2],  # [3, 8, 13, 18, etc...]
+ dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=131_072, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + ), + attn_implementation="sdpa", + pad_token_id=None, + ), + ) + +if MLLAMA_AVAILABLE: + MINI_MODEL_SETUPS["mini_mllama"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_mllama, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_mllama, + model_class=MllamaForConditionalGeneration, + mini_model_config=MllamaConfig( + vision_config=MllamaVisionConfig( + hidden_act="gelu", + hidden_size=512, # 1280 + image_size=560, # 560 + initializer_range=0.02, + intermediate_layers_indices=[2], # [3, 7, 15, etc...] + intermediate_size=2048, # 5120 + max_num_tiles=1, # 4 + norm_eps=1e-5, + num_attention_heads=4, # 16 + num_channels=3, + num_global_layers=2, # 8 + num_hidden_layers=8, # 32 + patch_size=140, # 14 + supported_aspect_ratios=[[1, 1]], # [[1, 1], [1, 2], etc... ] + vision_output_dim=1024, # 7680 + ), + text_config=MllamaTextConfig( + bos_token_id=0, + eos_token_id=0, + pad_token_id=0, + cross_attention_layers=[2], # [3, 8, 13, 18, etc...] + dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=131_072, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + rope_scaling=dict( + factor=8.0, + high_freq_factor=4.0, + low_freq_factor=1.0, + original_max_position_embeddings=8192, + rope_type="llama3", + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + ), + image_token_index=1, # NOTE: outside the vocab size + attn_implementation="sdpa", + ), + ) + +if PALIGEMMA_AVAILABLE: + MINI_MODEL_SETUPS["mini_paligemma"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_paligemma, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_Paligemma, + model_class=PaliGemmaForConditionalGeneration, + mini_model_config=PaliGemmaConfig( + vision_config=SiglipVisionConfig( + attention_dropout=0.0, + hidden_act="gelu_pytorch_tanh", + hidden_size=1152, + image_size=224, + intermediate_size=2048, # 4304 + layer_norm_eps=1e-06, + num_attention_heads=4, # 16 + num_channels=3, + num_hidden_layers=4, # 27 + num_image_tokens=256, + num_positions=256, + patch_size=14, + projection_dim=1024, # 2304 + ), + text_config=GemmaConfig( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + ), + image_token_index=4, # NOTE: outside the vocab size + attn_implementation="eager", + vocab_size=32000, + projection_dim=1024, + ), + ) + 
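# mini_paligemma2 below mirrors mini_paligemma, swapping the Gemma text backbone (GemmaConfig) for Gemma2 (Gemma2Config).
+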
MINI_MODEL_SETUPS["mini_paligemma2"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_paligemma, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_Paligemma, + model_class=PaliGemmaForConditionalGeneration, + mini_model_config=PaliGemmaConfig( + vision_config=SiglipVisionConfig( + attention_dropout=0.0, + hidden_act="gelu_pytorch_tanh", + hidden_size=1152, + image_size=224, + intermediate_size=2048, # 4304 + layer_norm_eps=1e-06, + num_attention_heads=4, # 16 + num_channels=3, + num_hidden_layers=4, # 27 + num_image_tokens=256, + num_positions=256, + patch_size=14, + projection_dim=1024, # 2304 + ), + text_config=Gemma2Config( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + ), + image_token_index=4, # NOTE: outside the vocab size + attn_implementation="eager", + vocab_size=32000, + projection_dim=1024, + ), + ) + +if GEMMA3_AVAILABLE: + MINI_MODEL_SETUPS["mini_gemma3"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_gemma3, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma3, + model_class=Gemma3ForConditionalGeneration, + mini_model_config=Gemma3Config( + vision_config=SiglipVisionConfig( + attention_dropout=0.0, + hidden_act="gelu_pytorch_tanh", + hidden_size=1152, + image_size=224, + intermediate_size=2048, # 4304 + layer_norm_eps=1e-06, + num_attention_heads=4, # 16 + num_channels=3, + num_hidden_layers=4, # 27 + num_image_tokens=256, + num_positions=256, + patch_size=14, + ).to_dict(), + text_config=Gemma3TextConfig( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + ).to_dict(), + image_token_index=5, # NOTE: outside the vocab size + boi_token_index=4, + eoi_token_index=6, + attn_implementation="eager", + ), + ) + +if QWEN2_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_qwen2_vl, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_vl, + model_class=Qwen2VLForConditionalGeneration, + mini_model_config=Qwen2VLConfig( + attention_dropout=0.0, + # Token Ids and vocab size must match those in the tokenizer/processor + # test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json + bos_token_id=0, + eos_token_id=0, + vision_start_token_id=1, + vision_end_token_id=2, + vision_token_id=3, + image_token_id=4, + video_token_id=5, + hidden_act="silu", + hidden_size=1024, # 8192 + initializer_range=0.02, + intermediate_size=1024, # 29568 
+ max_position_embeddings=32768, + max_window_layers=4, # 80 + num_attention_heads=8, # 64 + num_hidden_layers=4, # 80 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-6, # 1e-5 + **( + dict(rope_parameters=dict(mrope_section=[16, 24, 24])) # (temporal, height, width) + if IS_TRANSFORMERS_V5_OR_LATER + else dict(rope_scaling=dict(type="mrope", mrope_section=[16, 24, 24])) + ), + sliding_window=4096, + tie_word_embeddings=True, + use_cache=False, # True + vocab_size=32000, # 152064, + use_sliding_window=False, + vision_config={ + "depth": 4, # 32 + "embed_dim": 128, # 1280 + "mlp_ratio": 1, + "num_heads": 8, # 16 + "in_chans": 3, + "hidden_size": 1024, # 1536 + }, + attn_implementation="sdpa", + ), + ) + +if LLAVA_AVAILABLE: + # https://huggingface.co/llava-hf/llava-1.5-7b-hf + MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_llava, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_llava, + model_class=LlavaForConditionalGeneration, + mini_model_config=LlavaConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pretraining_tp=1, + tie_word_embeddings=False, + use_cache=True, + max_position_embeddings=4096, # llava-1.5-7b-hf + rms_norm_eps=1e-05, # llava-1.5-7b-hf + vocab_size=32064, # llava-1.5-7b-hf + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + vision_config=CLIPVisionConfig( + hidden_size=1024, + image_size=336, + intermediate_size=2048, # 4096 + model_type="clip_vision_model", + num_attention_heads=4, # 16 + num_hidden_layers=4, # 24 + patch_size=14, + projection_dim=768, + vocab_size=32000, + ), + vocab_size=32064, + ignore_index=-100, + pad_token_id=4, + image_token_index=3, + projector_hidden_act="gelu", + vision_feature_layer=-2, + vision_feature_select_strategy="default", + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if INTERNVL_AVAILABLE: + MINI_MODEL_SETUPS["mini_internvl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_internvl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_internvl, + model_class=InternVLForConditionalGeneration, + mini_model_config=InternVLConfig( + text_config=Qwen2Config( + rms_norm_eps=1e-5, + hidden_size=256, # 1024 + intermediate_size=1024, # 4096 + hidden_act="silu", + num_hidden_layers=4, # 24 + num_attention_heads=4, # 16 + num_key_value_heads=2, # 16 + max_position_embeddings=4096, # 8192 + vocab_size=32000, # 151936 + bos_token_id=1, + eos_token_id=2, + pad_token_id=2, + tie_word_embeddings=False, + ), + vision_config={ + "hidden_size": 256, # 1024 + "intermediate_size": 1024, # 4096 + "num_hidden_layers": 4, # 24 + "num_attention_heads": 4, # 16 + }, + image_token_id=24, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if SMOLVLM2_AVAILABLE: + MINI_MODEL_SETUPS["mini_smolvlm2"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_smolvlm, + 
liger_kernel_patch_revert_func=revert_liger_kernel_to_smolvlm2, + model_class=SmolVLMForConditionalGeneration, + mini_model_config=SmolVLMConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + pad_token_id=2, + hidden_act="silu", + hidden_size=576, # 576 for 256M model + initializer_range=0.041666666666666664, + intermediate_size=1536, # 1536 for 256M model + max_position_embeddings=8192, + num_attention_heads=9, # 9 for 256M model + num_hidden_layers=4, # 30 -> reduced to 4 for testing + num_key_value_heads=3, # 3 for 256M model + rms_norm_eps=1e-5, + tie_word_embeddings=False, + vocab_size=49280, + ), + vision_config={ + "hidden_size": 768, + "intermediate_size": 3072, + "num_hidden_layers": 4, # 12 -> reduced to 4 for testing + "num_attention_heads": 12, + "image_size": 512, + "patch_size": 16, + }, + image_token_id=49190, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if QWEN2_5_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_5_vl"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_qwen2_5_vl, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_5_vl, + model_class=Qwen2_5_VLForConditionalGeneration, + mini_model_config=Qwen2_5_VLConfig( + attention_dropout=0.0, + # Token Ids and vocab size must match those in the tokenizer/processor + # test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json + bos_token_id=0, + eos_token_id=0, + vision_start_token_id=1, + vision_end_token_id=2, + vision_token_id=3, + image_token_id=4, + video_token_id=5, + hidden_act="silu", + hidden_size=1024, # 8192 + initializer_range=0.02, + intermediate_size=1024, # 29568 + max_position_embeddings=32768, + max_window_layers=4, # 80 + num_attention_heads=8, # 64 + num_hidden_layers=4, # 80 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-6, # 1e-5 + **( + dict(rope_parameters=dict(mrope_section=[16, 24, 24])) # (temporal, height, width) + if IS_TRANSFORMERS_V5_OR_LATER + else dict(rope_scaling=dict(type="mrope", mrope_section=[16, 24, 24])) + ), + sliding_window=4096, + tie_word_embeddings=True, + use_cache=False, # True + vocab_size=32000, # 152064, + use_sliding_window=False, + vision_config={ + "depth": 4, # 32 + "hidden_size": 128, # 1280 + "num_heads": 16, + "in_chans": 3, + "out_hidden_size": 1024, + }, + attn_implementation="sdpa", + ), + ) + +if QWEN3_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_vl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_vl, + model_class=Qwen3VLForConditionalGeneration, + mini_model_config=Qwen3VLConfig( + attn_implementation="sdpa", + image_token_id=4, + video_token_id=5, + vision_start_token_id=1, + vision_end_token_id=2, + tie_word_embeddings=True, + vision_config=Qwen3VLVisionConfig( + depth=4, + hidden_size=256, + hidden_act="gelu_pytorch_tanh", + intermediate_size=512, + num_heads=4, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=512, + num_position_embeddings=256, + deepstack_visual_indexes=[1, 2, 3], + initializer_range=0.02, + ).to_dict(), + text_config=Qwen3VLTextConfig( + vocab_size=32000, + hidden_size=512, + intermediate_size=2048, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + head_dim=64, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + 
use_cache=False, + tie_word_embeddings=True, + rope_scaling=dict( + type="mrope", + mrope_section=[16, 24, 24], # (temporal, height, width) + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + attention_dropout=0.0, + attention_bias=False, + ).to_dict(), + ), + ) + +if QWEN3_VL_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_vl_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_vl_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_vl_moe, + model_class=Qwen3VLMoeForConditionalGeneration, + mini_model_config=Qwen3VLMoeConfig( + attn_implementation="sdpa", + image_token_id=4, + video_token_id=5, + vision_start_token_id=1, + vision_end_token_id=2, + tie_word_embeddings=True, + vision_config=Qwen3VLMoeVisionConfig( + depth=4, + hidden_size=256, + hidden_act="gelu_pytorch_tanh", + intermediate_size=512, + num_heads=4, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=512, + num_position_embeddings=256, + deepstack_visual_indexes=[1, 2, 3], + initializer_range=0.02, + ).to_dict(), + text_config=Qwen3VLMoeTextConfig( + vocab_size=32000, + hidden_size=512, + intermediate_size=2048, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + head_dim=64, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=False, + tie_word_embeddings=True, + rope_scaling=dict( + type="mrope", + mrope_section=[16, 24, 24], # (temporal, height, width) + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + attention_dropout=0.0, + attention_bias=False, + decoder_sparse_step=1, + moe_intermediate_size=1024, + num_experts_per_tok=2, + num_experts=4, + mlp_only_layers=[], + pad_token_id=None, + ).to_dict(), + ), + ) + +if QWEN3_5_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_5"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_5, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_5, + model_class=Qwen3_5ForConditionalGeneration, + mini_model_config=Qwen3_5Config( + attn_implementation="sdpa", + image_token_id=4, + video_token_id=5, + vision_start_token_id=1, + vision_end_token_id=2, + tie_word_embeddings=True, + vision_config=Qwen3_5VisionConfig( + depth=4, + hidden_size=256, + hidden_act="gelu_pytorch_tanh", + intermediate_size=512, + num_heads=4, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=512, + num_position_embeddings=256, + deepstack_visual_indexes=[1, 2, 3], + initializer_range=0.02, + ).to_dict(), + text_config=Qwen3_5TextConfig( + vocab_size=32000, + hidden_size=512, + intermediate_size=2048, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + head_dim=64, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=False, + tie_word_embeddings=True, + rope_scaling=dict( + type="mrope", + mrope_section=[16, 24, 24], # (temporal, height, width) + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + attention_dropout=0.0, + attention_bias=False, + decoder_sparse_step=1, + moe_intermediate_size=1024, + num_experts_per_tok=2, + num_experts=4, + mlp_only_layers=[], + pad_token_id=None, + ).to_dict(), + ), + ) + + +if PIXTRAL_AVAILABLE: + MINI_MODEL_SETUPS["mini_pixtral"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_pixtral, + liger_kernel_patch_revert_func=revert_liger_kernel_to_pixtral, + model_class=PixtralVisionModel, + mini_model_config=PixtralVisionConfig( + 
hidden_size=1024, + intermediate_size=2048, + num_hidden_layers=4, + num_attention_heads=8, + num_channels=3, + image_size=256, + patch_size=16, + hidden_act="silu", + attention_dropout=0.0, + rope_theta=10000.0, + initializer_range=0.02, + ), + ) + + +def create_processor(model_name: str): + if model_name == "mini_qwen2_vl": + tokenizer_config = load_tokenizer_config( + os.path.join(FAKE_CONFIGS_PATH, "Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json") + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + qwen_tokenizer = Qwen2Tokenizer(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = Qwen2VLImageProcessor() + video_processor = Qwen2VLVideoProcessor() + return Qwen2VLProcessor( + image_processor=image_processor, + video_processor=video_processor, + tokenizer=qwen_tokenizer, + ) + + elif model_name == "mini_qwen2_5_vl": + tokenizer_config = load_tokenizer_config( + os.path.join(FAKE_CONFIGS_PATH, "Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json") + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + qwen_tokenizer = Qwen2Tokenizer(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = Qwen2VLImageProcessor() + video_processor = Qwen2VLVideoProcessor() + return Qwen2_5_VLProcessor( + image_processor=image_processor, + video_processor=video_processor, + tokenizer=qwen_tokenizer, + ) + + elif model_name in ("mini_qwen3_vl", "mini_qwen3_vl_moe", "mini_qwen3_5"): + tokenizer_config = load_tokenizer_config( + os.path.join(FAKE_CONFIGS_PATH, "Qwen/Qwen3-VL-4B-Instruct/tokenizer_config.json") + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + qwen_tokenizer = Qwen2Tokenizer(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = Qwen2VLImageProcessor(patch_size=16, temporal_patch_size=2, merge_size=2) + video_processor = Qwen3VLVideoProcessor() + return Qwen3VLProcessor( + image_processor=image_processor, + video_processor=video_processor, + tokenizer=qwen_tokenizer, + ) + + elif model_name == "mini_llava": + tokenizer_config = load_tokenizer_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Llava/llava-1.5-7b-hf/tokenizer_config.json", + ) + ) + image_processor_config = load_image_processing_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Llava/llava-1.5-7b-hf/preprocessor_config.json", + ) + ) + processor_config = load_processor_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Llava/llava-1.5-7b-hf/processor_config.json", + ) + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + + fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config) + fast_tokenizer.model_input_names = ["input_ids", "attention_mask"] + image_processor = CLIPImageProcessor(**image_processor_config) + + return LlavaProcessor(**processor_config, image_processor=image_processor, tokenizer=fast_tokenizer) + + elif model_name == "mini_internvl": + tokenizer_config = load_tokenizer_config( + os.path.join(FAKE_CONFIGS_PATH, "OpenGVLab/InternVL3-1B-hf/tokenizer_config.json") + ) + tokenizer_base = train_bpe_tokenizer( + [ + 
token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + qwen_tokenizer = Qwen2Tokenizer(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = GotOcr2ImageProcessorFast( + crop_to_patches=False, min_patches=1, max_patches=12, size={"height": 448, "width": 448} + ) + video_processor = InternVLVideoProcessor() + + # Return proper InternVL processor + return InternVLProcessor( + image_processor=image_processor, tokenizer=qwen_tokenizer, video_processor=video_processor + ) + + elif model_name == "mini_smolvlm2": + tokenizer_config = load_tokenizer_config( + os.path.join(FAKE_CONFIGS_PATH, "HuggingFaceTB/SmolVLM2-256M-Video-Instruct/tokenizer_config.json") + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + gpt2_tokenizer = GPT2Tokenizer(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = SmolVLMImageProcessor(size={"longest_edge": 512}) + video_processor = SmolVLMVideoProcessor() + + # Return proper SmolVLM processor + return SmolVLMProcessor( + image_processor=image_processor, tokenizer=gpt2_tokenizer, video_processor=video_processor + ) + + elif model_name.startswith("mini_llama4"): + tokenizer_config = load_tokenizer_config( + os.path.join( + FAKE_CONFIGS_PATH, + "meta-llama/Llama-4-Scout-17B-16E-Instruct/tokenizer_config.json", + ) + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = Llama4ImageProcessorFast(size={"height": 560, "width": 560}) + return Llama4Processor( + image_processor=image_processor, + tokenizer=fast_tokenizer, + fake_image_token="<|image|>", + image_token="<|image|>", + ) + elif model_name == "mini_mllama": + tokenizer_config = load_tokenizer_config( + os.path.join( + FAKE_CONFIGS_PATH, + "meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json", + ) + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = MllamaImageProcessor(size={"height": 560, "width": 560}) + return MllamaProcessor(image_processor=image_processor, tokenizer=fast_tokenizer) + + elif model_name.startswith("mini_paligemma"): + tokenizer_config = load_tokenizer_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json", + ) + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + + fast_tokenizer = GemmaTokenizer(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = SiglipImageProcessor(size={"height": 224, "width": 224}, image_seq_length=256) + return PaliGemmaProcessor(image_processor=image_processor, tokenizer=fast_tokenizer) + + elif model_name.startswith("mini_gemma3"): + tokenizer_config = load_tokenizer_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Google/Gemma3/gemma-3-4b-it/tokenizer_config.json", + ) + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + 
for key, token in sorted(
+                    tokenizer_config["added_tokens_decoder"].items(),
+                    key=lambda x: int(x[0]),
+                )
+            ]
+        )
+        fast_tokenizer = GemmaTokenizer(tokenizer_object=tokenizer_base, **tokenizer_config)
+        image_processor = Gemma3ImageProcessor()
+        return Gemma3Processor(image_processor=image_processor, tokenizer=fast_tokenizer)
+
+    else:
+        raise ValueError(f"Processor not available for model {model_name}")
+
+
+def create_multimodal_dataset(model_name: str):
+    processor = create_processor(model_name)
+
+    def generate_procedural_image(example, index):
+        """Generate an image with a single row of white pixels at the index specified"""
+        image = torch.zeros(3, TEST_IMAGE_DIM, TEST_IMAGE_DIM)
+        image[:, index % TEST_IMAGE_DIM, :] = 255
+        example["image"] = image
+        return example
+
+    def apply_chat_template(example):
+        """
+        Under the hood, this inserts the correct image placeholder token into the text.
+        More or less, this conversation format is used by HF's MLLMs. The fact that it is
+        formatted as IFT data is not in and of itself important here.
+        """
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            },
+            {
+                "role": "assistant",
+                "content": [{"type": "text", "text": example["text"]}],
+            },
+        ]
+        example["text"] = processor.tokenizer.apply_chat_template(conversation, tokenize=False)
+        return example
+
+    def preprocess_function(examples):
+        """Tokenize text, preprocess images, and generate other relevant inputs for the model."""
+        if model_name == "mini_llama4":
+            # Process images and text separately to avoid complex token replacement; this allowed setting a lower tolerance than processing them together.
+            image_inputs = processor.image_processor(images=examples["image"], return_tensors="pt")
+            text_inputs = processor.tokenizer(
+                examples["text"],
+                padding="max_length",
+                truncation=True,
+                max_length=1024,
+                return_tensors="pt",
+            )
+            return {**text_inputs, **image_inputs}
+        else:
+            # For other models, use the normal processor
+            results = processor(
+                text=examples["text"],
+                images=examples["image"],
+                padding="max_length",
+                truncation=True,
+                max_length=1024,  # longer than for text-only b/c images require quite a few tokens
+                return_tensors="pt",
+            )
+            return results
+
+    train_dataset = (
+        load_dataset("text", data_files={"train": UNTOKENIZED_DATASET_PATH}, split="train")
+        .to_iterable_dataset()  # only map examples as needed, on demand
+        .map(generate_procedural_image, with_indices=True)
+        .map(apply_chat_template)
+        .map(preprocess_function, remove_columns=["text", "image"])
+    )
+    return train_dataset
+
+
+def create_model(model_name):
+    """
+    Create a mini version of the model.
+    The commented values are the original values.
+    """
+    model_config = MINI_MODEL_SETUPS[model_name].mini_model_config
+    model_class = MINI_MODEL_SETUPS[model_name].model_class
+    return model_class(model_config)
+
+
+@require_deterministic
+def run_mini_model_multimodal(
+    model_name="mini_qwen2_vl",
+    num_steps=100,
+    dtype=torch.bfloat16,
+    lr=1e-5,
+    with_liger=False,
+):
+    # If we moved set_seed to the beginning of test_mini_model_multimodal, the two runs would be initialized with different weights.
+    # This is due to the RNG (Random Number Generator): for a linear congruential generator, for example, the state progresses as x_(n+1) = (a * x_n + c) % m.
+    # Every time the RNG is used, e.g. when randomly initializing weights, it advances to the next state.
+    # Therefore, we have to reset the RNG before we create the model to ensure that weight initialization starts from the same RNG state.
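+    # Illustrative sketch of this behavior (assumes torch's default CPU generator; not executed by this test):
+    #     torch.manual_seed(42); a = torch.randn(3)  # each draw advances the RNG state
+    #     b = torch.randn(3)                         # b != a, since the state has moved on
+    #     torch.manual_seed(42); c = torch.randn(3)  # c == a, since the state was reset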
+ + set_seed(42) + + revert_kwargs = {"model_config": MINI_MODEL_SETUPS[model_name]} + if "mllama" in model_name or "llama4" in model_name or "qwen3_5" in model_name: + revert_kwargs["model_type"] = "conditional_generation" + + if with_liger is True: + kwargs = { + "rope": True, + "rms_norm": True, + "cross_entropy": False, + } + + if ( + "qwen2_5_vl" not in model_name + and "llava" not in model_name + and "qwen3_vl" not in model_name + and "qwen3_5" not in model_name + ): + kwargs["layer_norm"] = True + + if "qwen3_5" in model_name: + kwargs["rope"] = False + + if "gemma" in model_name: + kwargs["geglu"] = True + else: + kwargs["swiglu"] = True + + if "llava" in model_name: + apply_liger_kernel_to_llama(**kwargs) + + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) + else: + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) + + model = create_model(model_name).to(dtype).to(device) + model.gradient_checkpointing_enable() + + train_dataset = create_multimodal_dataset(model_name) + loader = DataLoader(train_dataset, batch_size=2, shuffle=False, collate_fn=multimodal_collate_fn) + loader_iter = iter(loader) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr) + + loss_list = [] + + for i in range(num_steps): + batch = next(loader_iter).to(model.device) + optimizer.zero_grad() + supports_accum = getattr(model, "_supports_accum_dtype", None) + if supports_accum is None: + import inspect + + params = inspect.signature(model.forward).parameters + supports_accum = ("accum_dtype" in params) or any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values() + ) + setattr(model, "_supports_accum_dtype", supports_accum) + + output = model(**batch, accum_dtype=torch.float32) if supports_accum else model(**batch) + output.loss.backward() + optimizer.step() + + print(f"Step {i}, Loss: {output.loss.item()}") + loss_list.append(output.loss.item()) + + model.eval() + eval_batch = next(loader_iter).to(model.device) + if with_liger: + eval_batch["skip_logits"] = False + with torch.no_grad(): + eval_output = model(**eval_batch) + print(f"Eval Loss: {eval_output.loss.item()}") + loss_list.append(eval_output.loss.item()) + topk_logprobs = get_topk(get_logprobs(eval_output.logits)) + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) + return { + "loss": loss_list, + "topk_logprobs": topk_logprobs.values, + "model": model, + } + + +@pytest.mark.parametrize( + "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol", + [ + pytest.param( + "mini_qwen2_vl", + 32, + 1e-4, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", + ), + pytest.mark.skipif(not is_torchvision_available(), reason="Qwen2VLVideoProcessor requires torchvision"), + ], + ), + pytest.param( + "mini_llava", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_internvl", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), 
reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not INTERNVL_AVAILABLE, + reason="InternVL not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_smolvlm2", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not SMOLVLM2_AVAILABLE, + reason="SmolVLM2 not available in this version of transformers", + ), + pytest.mark.skipif( + not NUM2WORDS_AVAILABLE, + reason="num2words must be present to run SmolVLMProcessor", + ), + ], + ), + pytest.param( + "mini_qwen2_5_vl", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN2_5_VL_AVAILABLE, + reason="Qwen2.5-VL not available in this version of transformers", + ), + pytest.mark.skipif(not is_torchvision_available(), reason="Qwen2VLVideoProcessor requires torchvision"), + ], + ), + pytest.param( + "mini_qwen3_vl", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3_VL_AVAILABLE, + reason="Qwen3-VL not available in this version of transformers", + ), + pytest.mark.skipif( + not is_torchvision_available(), + reason="Qwen3VLVideoProcessor requires torchvision", + ), + ], + ), + pytest.param( + "mini_qwen3_vl_moe", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3_VL_MOE_AVAILABLE, + reason="Qwen3-VL-MoE not available in this version of transformers", + ), + pytest.mark.skipif( + not is_torchvision_available(), + reason="Qwen3VLVideoProcessor requires torchvision", + ), + ], + ), + pytest.param( + "mini_mllama", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not MLLAMA_AVAILABLE, + reason="Mllama not available in this version of transformers", + ), + pytest.mark.skipif( + version.parse("4.51.0") > version.parse(transformers.__version__), + reason="MllamaForConditionalGeneration doesn't accecpt `skip_logits` kwargs", + ), + ], + ), + pytest.param( + "mini_llama4", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-1, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not LLAMA4_AVAILABLE, + reason="Llama4 not available in this version of transformers", + ), + # TODO: Remove this skipif when the bug fix is released in Transformers + pytest.mark.skipif( + version.parse(transformers.__version__) <= version.parse("5.1.0"), + reason="Wait for this bug fix to be released in Transformers: https://github.com/huggingface/transformers/pull/43882", + ), + ], + ), + pytest.param( + "mini_paligemma", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not PALIGEMMA_AVAILABLE, + reason="Paligemma not available in this version of transformers", + ), + ], + ), + pytest.param( + 
"mini_paligemma2", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not PALIGEMMA_AVAILABLE, + reason="Paligemma2 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_gemma3", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-1, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GEMMA3_AVAILABLE, + reason="Gemma3 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen3_5", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3_5_AVAILABLE, + reason="Qwen3.5 not available in this version of transformers", + ), + pytest.mark.skipif( + not is_torchvision_available(), + reason="Qwen3VLVideoProcessor requires torchvision", + ), + ], + ), + ], +) +def test_mini_model_multimodal( + model_name, + num_steps, + lr, + dtype, + loss_atol, + loss_rtol, + logprobs_atol, + logprobs_rtol, + param_atol, + param_rtol, +): + # Non-liger models should be initialized and tested first to avoid the module being overridden + expected_output = run_mini_model_multimodal(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr) + + actual_output = run_mini_model_multimodal( + model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True + ) + + # Compare the loss of every step + assert_verbose_allclose( + torch.tensor([expected_output["loss"]]), + torch.tensor([actual_output["loss"]]), + atol=loss_atol, + rtol=loss_rtol, + extra_info="[Loss]", + ) + + # Compare the topk logprobs from evaluation step + assert_verbose_allclose( + expected_output["topk_logprobs"], + actual_output["topk_logprobs"], + atol=logprobs_atol, + rtol=logprobs_rtol, + extra_info="[Top k logprobs]", + ) + + # Compare the params from the last step + # Iterate over the model's parameters and compare them + for expected_param, actual_param in zip( + expected_output["model"].named_parameters(), + actual_output["model"].named_parameters(), + ): + assert_verbose_allclose( + expected_param[1], + actual_param[1], + atol=param_atol, + rtol=param_rtol, + extra_info="[Model parameters]", + ) + + +# +# Vision-only model tests (e.g. Pixtral vision encoder) +# + + +def generate_procedural_pixel_values(batch_size, num_channels, image_size, index, dtype, device): + """Generate deterministic pixel values for vision-only model testing. + + Each image has a single row of white pixels at a deterministic position, + providing a reproducible signal for convergence testing. 
+ """ + pixel_values = torch.zeros(batch_size, num_channels, image_size, image_size, dtype=dtype, device=device) + for b in range(batch_size): + row = (index + b) % image_size + pixel_values[b, :, row, :] = 1.0 + return pixel_values + + +@require_deterministic +def run_mini_model_vision( + model_name="mini_pixtral", + num_steps=100, + dtype=torch.bfloat16, + lr=1e-5, + with_liger=False, +): + set_seed(42) + + revert_kwargs = {"model_config": MINI_MODEL_SETUPS[model_name]} + + if with_liger is True: + kwargs = { + "rope": True, + "rms_norm": True, + "swiglu": True, + } + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) + else: + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) + + model = create_model(model_name).to(dtype).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr) + + loss_list = [] + + for i in range(num_steps): + optimizer.zero_grad() + pixel_values = generate_procedural_pixel_values( + batch_size=2, + num_channels=model.config.num_channels, + image_size=model.config.image_size, + index=i, + dtype=dtype, + device=device, + ) + output = model(pixel_values=pixel_values) + loss = output.last_hidden_state.sum() + loss.backward() + optimizer.step() + print(f"Step {i}, Loss: {loss.item()}") + loss_list.append(loss.item()) + + # Eval step with deterministic input + model.eval() + with torch.no_grad(): + eval_pixel_values = generate_procedural_pixel_values( + batch_size=2, + num_channels=model.config.num_channels, + image_size=model.config.image_size, + index=num_steps, + dtype=dtype, + device=device, + ) + eval_output = model(pixel_values=eval_pixel_values) + + topk_logprobs = get_topk(get_logprobs(eval_output.last_hidden_state)) + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) + return { + "loss": loss_list, + "topk_logprobs": topk_logprobs.values, + "model": model, + } + + +@pytest.mark.parametrize( + "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol", + [ + pytest.param( + "mini_pixtral", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-0, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not PIXTRAL_AVAILABLE, reason="Pixtral not available in this version of transformers" + ), + ], + ), + ], +) +def test_mini_model_vision( + model_name, + num_steps, + lr, + dtype, + loss_atol, + loss_rtol, + logprobs_atol, + logprobs_rtol, + param_atol, + param_rtol, +): + # Non-liger models should be initialized and tested first to avoid the module being overridden + expected_output = run_mini_model_vision(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr) + + actual_output = run_mini_model_vision( + model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True + ) + + # Compare the loss of every step + assert_verbose_allclose( + torch.tensor([expected_output["loss"]]), + torch.tensor([actual_output["loss"]]), + atol=loss_atol, + rtol=loss_rtol, + extra_info="[Loss]", + ) + + # Compare the topk logprobs from evaluation step + assert_verbose_allclose( + expected_output["topk_logprobs"], + actual_output["topk_logprobs"], + atol=logprobs_atol, + rtol=logprobs_rtol, + extra_info="[Top k logprobs]", + ) + + # Compare the params from the last step + for expected_param, actual_param in zip( + expected_output["model"].named_parameters(), + actual_output["model"].named_parameters(), + ): + assert_verbose_allclose( + 
expected_param[1], + actual_param[1], + atol=param_atol, + rtol=param_rtol, + extra_info="[Model parameters]", + ) diff --git a/test/convergence/bf16/test_mini_models_with_logits.py b/test/convergence/bf16/test_mini_models_with_logits.py new file mode 100755 index 0000000000000000000000000000000000000000..ff0d12304f8af5adca7707686cf309101ab7ecd8 --- /dev/null +++ b/test/convergence/bf16/test_mini_models_with_logits.py @@ -0,0 +1,2167 @@ +import os + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Ensure deterministic behavior with CuBLAS + +import pytest +import torch +import transformers + +from datasets import load_from_disk +from packaging import version +from torch.utils.data import DataLoader +from transformers.models.gemma import GemmaConfig +from transformers.models.gemma import GemmaForCausalLM +from transformers.models.gemma2 import Gemma2Config +from transformers.models.gemma2 import Gemma2ForCausalLM +from transformers.models.llama import LlamaConfig +from transformers.models.llama import LlamaForCausalLM +from transformers.models.mistral import MistralConfig +from transformers.models.mistral import MistralForCausalLM +from transformers.models.mixtral import MixtralConfig +from transformers.models.mixtral import MixtralForCausalLM +from transformers.models.phi3 import Phi3Config +from transformers.models.phi3 import Phi3ForCausalLM +from transformers.models.qwen2 import Qwen2Config +from transformers.models.qwen2 import Qwen2ForCausalLM + +from liger_kernel.transformers import apply_liger_kernel_to_exaone4 +from liger_kernel.transformers import apply_liger_kernel_to_falcon_h1 +from liger_kernel.transformers import apply_liger_kernel_to_gemma +from liger_kernel.transformers import apply_liger_kernel_to_gemma2 +from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text +from liger_kernel.transformers import apply_liger_kernel_to_glm4 +from liger_kernel.transformers import apply_liger_kernel_to_glm4v +from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe +from liger_kernel.transformers import apply_liger_kernel_to_granite +from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_dense +from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_moe +from liger_kernel.transformers import apply_liger_kernel_to_internvl +from liger_kernel.transformers import apply_liger_kernel_to_llama +from liger_kernel.transformers import apply_liger_kernel_to_llama4 +from liger_kernel.transformers import apply_liger_kernel_to_llava +from liger_kernel.transformers import apply_liger_kernel_to_mistral +from liger_kernel.transformers import apply_liger_kernel_to_mixtral +from liger_kernel.transformers import apply_liger_kernel_to_mllama +from liger_kernel.transformers import apply_liger_kernel_to_olmo2 +from liger_kernel.transformers import apply_liger_kernel_to_olmo3 +from liger_kernel.transformers import apply_liger_kernel_to_phi3 +from liger_kernel.transformers import apply_liger_kernel_to_qwen2 +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen3 +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl +from liger_kernel.transformers import 
apply_liger_kernel_to_qwen3_vl_moe +from liger_kernel.transformers import apply_liger_kernel_to_smollm3 +from liger_kernel.utils import infer_device +from test.utils import DEFAULT_DATASET_PATH +from test.utils import MiniModelConfig +from test.utils import assert_verbose_allclose +from test.utils import get_logprobs +from test.utils import get_topk +from test.utils import require_deterministic +from test.utils import revert_liger_kernel_to_exaone4 +from test.utils import revert_liger_kernel_to_falcon_h1 +from test.utils import revert_liger_kernel_to_gemma +from test.utils import revert_liger_kernel_to_gemma2 +from test.utils import revert_liger_kernel_to_gemma3_text +from test.utils import revert_liger_kernel_to_glm4 +from test.utils import revert_liger_kernel_to_glm4v +from test.utils import revert_liger_kernel_to_glm4v_moe +from test.utils import revert_liger_kernel_to_granite +from test.utils import revert_liger_kernel_to_hunyuan_v1 +from test.utils import revert_liger_kernel_to_hunyuan_v1_moe +from test.utils import revert_liger_kernel_to_internvl +from test.utils import revert_liger_kernel_to_llama +from test.utils import revert_liger_kernel_to_llama4 +from test.utils import revert_liger_kernel_to_llava +from test.utils import revert_liger_kernel_to_mistral +from test.utils import revert_liger_kernel_to_mixtral +from test.utils import revert_liger_kernel_to_mllama +from test.utils import revert_liger_kernel_to_olmo2 +from test.utils import revert_liger_kernel_to_olmo3 +from test.utils import revert_liger_kernel_to_phi3 +from test.utils import revert_liger_kernel_to_qwen2 +from test.utils import revert_liger_kernel_to_qwen2_5_vl +from test.utils import revert_liger_kernel_to_qwen2_vl +from test.utils import revert_liger_kernel_to_qwen3 +from test.utils import revert_liger_kernel_to_qwen3_5 +from test.utils import revert_liger_kernel_to_qwen3_moe +from test.utils import revert_liger_kernel_to_qwen3_next +from test.utils import revert_liger_kernel_to_qwen3_vl +from test.utils import revert_liger_kernel_to_qwen3_vl_moe +from test.utils import revert_liger_kernel_to_smollm3 +from test.utils import set_seed +from test.utils import simple_collate_fn +from test.utils import supports_bfloat16 + +IS_TRANSFORMERS_V5_OR_LATER = version.parse(transformers.__version__) >= version.parse("5.0.0") + +try: + from transformers.models.llama4.configuration_llama4 import Llama4TextConfig + from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM + + LLAMA4_AVAILABLE = True +except ImportError: + LLAMA4_AVAILABLE = False + +try: + # Mllama is only available in transformers>=4.45.0 + from transformers.models.mllama.configuration_mllama import MllamaTextConfig + from transformers.models.mllama.modeling_mllama import MllamaForCausalLM + + MLLAMA_AVAILABLE = True +except ImportError: + MLLAMA_AVAILABLE = False + +try: + # Qwen2-VL is only available in transformers>4.52.4 + import transformers + + from packaging import version + from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration + + QWEN2_VL_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.52.4") +except ImportError: + QWEN2_VL_AVAILABLE = False + +try: + # Qwen2.5-VL is only available in transformers>4.52.4 + import transformers + + from packaging import version + from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl 
import Qwen2_5_VLForConditionalGeneration + + QWEN2_5_VL_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.52.4") +except ImportError: + QWEN2_5_VL_AVAILABLE = False + +try: + from transformers.models.qwen3.configuration_qwen3 import Qwen3Config + from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM + from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig + from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeForCausalLM + + QWEN3_AVAILABLE = True +except ImportError: + QWEN3_AVAILABLE = False + +try: + from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration + + QWEN3_VL_AVAILABLE = True +except ImportError: + QWEN3_VL_AVAILABLE = False + +try: + import transformers + + from packaging import version + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeTextConfig + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeVisionConfig + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration + + QWEN3_VL_MOE_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.57.0") +except ImportError: + QWEN3_VL_MOE_AVAILABLE = False + +try: + from transformers.models.granite import GraniteConfig + from transformers.models.granite import GraniteForCausalLM + + GRANITE_AVAILABLE = True +except ImportError: + GRANITE_AVAILABLE = False + +try: + from transformers import CLIPVisionConfig + from transformers.models.llava.configuration_llava import LlavaConfig + from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration + + LLAVA_AVAILABLE = True +except ImportError: + LLAVA_AVAILABLE = False + +try: + # OLMO2 is only available in transformers>=4.47.0 + from transformers.models.olmo2.configuration_olmo2 import Olmo2Config + from transformers.models.olmo2.modeling_olmo2 import Olmo2ForCausalLM + + OLMO2_AVAILABLE = True +except ImportError: + OLMO2_AVAILABLE = False + +try: + # OLMO3 is only available in transformers>=4.57.0 + from transformers.models.olmo3.configuration_olmo3 import Olmo3Config + from transformers.models.olmo3.modeling_olmo3 import Olmo3ForCausalLM + + OLMO3_AVAILABLE = True +except ImportError: + OLMO3_AVAILABLE = False + +try: + # Glm4 is only available in transformers>=4.51.3 + from transformers.models.glm4.configuration_glm4 import Glm4Config + from transformers.models.glm4.modeling_glm4 import Glm4ForCausalLM + + GLM4_AVAILABLE = True +except ImportError: + GLM4_AVAILABLE = False + +try: + # Glm4v is only available in transformers>=4.51.3 + from transformers.models.glm4v.configuration_glm4v import Glm4vConfig + from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration + + GLM4V_AVAILABLE = True +except ImportError: + GLM4V_AVAILABLE = False + +try: + # Glm4v_moe is only available in transformers>=4.53.1 + from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig + from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration + + GLM4V_MOE_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.53.1") +except ImportError: + GLM4V_MOE_AVAILABLE = False + +try: + from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig + from
transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM + + GEMMA3_AVAILABLE = True +except ImportError: + GEMMA3_AVAILABLE = False + +try: + # Smollm3 is only available in transformers>=4.53.0 + from transformers.models.smollm3.configuration_smollm3 import SmolLM3Config + from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM + + SMOLLM3_AVAILABLE = True +except ImportError: + SMOLLM3_AVAILABLE = False + +try: + # InternVL is only available in transformers>=4.52.1 + from transformers.models.internvl.configuration_internvl import InternVLConfig + from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration + + INTERNVL_AVAILABLE = True +except ImportError: + INTERNVL_AVAILABLE = False + +try: + # FalconH1 is only available in transformers>=4.53.0 + from transformers.models.falcon_h1.configuration_falcon_h1 import FalconH1Config + from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1ForCausalLM + + FALCONH1_AVAILABLE = True +except ImportError: + FALCONH1_AVAILABLE = False + +try: + # Qwen3Next is only available in transformers>=4.57.0 + from transformers.models.qwen3_next.configuration_qwen3_next import Qwen3NextConfig + from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM + + QWEN3NEXT_AVAILABLE = True +except ImportError: + QWEN3NEXT_AVAILABLE = False + +try: + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM + + QWEN3_5_AVAILABLE = True +except ImportError: + QWEN3_5_AVAILABLE = False + +try: + from transformers.models.hunyuan_v1_dense.configuration_hunyuan_v1_dense import HunYuanDenseV1Config + from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1ForCausalLM + from transformers.models.hunyuan_v1_moe.configuration_hunyuan_v1_moe import HunYuanMoEV1Config + from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1ForCausalLM + + HUNYUAN_V1_AVAILABLE = True +except ImportError: + HUNYUAN_V1_AVAILABLE = False + +try: + from transformers.models.exaone4.configuration_exaone4 import Exaone4Config + from transformers.models.exaone4.modeling_exaone4 import Exaone4ForCausalLM + + EXAONE4_AVAILABLE = True +except ImportError: + EXAONE4_AVAILABLE = False + + +device = infer_device() + +MINI_MODEL_SETUPS = { + "mini_llama3": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llama, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llama, + model_class=LlamaForCausalLM, + mini_model_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=8192, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + pretraining_tp=1, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ), + "mini_qwen2": MiniModelConfig( + 
liger_kernel_patch_func=apply_liger_kernel_to_qwen2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2, + model_class=Qwen2ForCausalLM, + mini_model_config=Qwen2Config( + attention_dropout=0.0, + bos_token_id=1, # 151643 + eos_token_id=2, # 151643 + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, # 131072 + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-6, + sliding_window=131072, + tie_word_embeddings=True, + use_cache=True, + vocab_size=32000, # 151936 + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ), + "mini_phi3": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_phi3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_phi3, + model_class=Phi3ForCausalLM, + mini_model_config=Phi3Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, # 32000 + hidden_act="silu", + hidden_size=896, # 3072 + initializer_range=0.02, + intermediate_size=4864, # 8192 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=None, # defaults to num_attention_heads + rms_norm_eps=1e-5, + sliding_window=None, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32064, + attn_implementation="eager", + ), + ), + "mini_mistral": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mistral, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mistral, + model_class=MistralForCausalLM, + mini_model_config=MistralConfig( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-5, + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ), + "mini_mixtral": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mixtral, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mixtral, + model_class=MixtralForCausalLM, + mini_model_config=MixtralConfig( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=512, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=32768, # 32768 + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ), + "mini_gemma1": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma, + model_class=GemmaForCausalLM, + mini_model_config=GemmaConfig( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + # gemma1 model config uses `hidden_act` and points it to gelu, + # https://huggingface.co/google/gemma-7b/blob/main/config.json#L10 + # but in reality it's ignored and HuggingFace will use the tanh approximation: + #
https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175 + hidden_act="gelu", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + ), + ), + "mini_gemma1.1": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma, + model_class=GemmaForCausalLM, + mini_model_config=GemmaConfig( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + ), + ), + "mini_gemma2": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma2, + model_class=Gemma2ForCausalLM, + mini_model_config=Gemma2Config( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + attn_implementation="eager", + ), + ), +} + +if LLAMA4_AVAILABLE: + MINI_MODEL_SETUPS["mini_llama4"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llama4, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llama4, + model_class=Llama4ForCausalLM, + mini_model_config=Llama4TextConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + partial_rotary_factor=1.0, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if QWEN3_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3, + model_class=Qwen3ForCausalLM, + mini_model_config=Qwen3Config( + attention_dropout=0.0, + bos_token_id=1, + 
eos_token_id=2, + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-6, + sliding_window=131072, + tie_word_embeddings=True, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ) + + MINI_MODEL_SETUPS["mini_qwen3_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_moe, + model_class=Qwen3MoeForCausalLM, + mini_model_config=Qwen3MoeConfig( + vocab_size=32000, # 151936 + hidden_size=896, + intermediate_size=4864, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + attention_bias=False, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + decoder_sparse_step=1, + moe_intermediate_size=768, + num_experts_per_tok=2, + num_experts=8, + norm_topk_prob=False, + output_router_logits=False, + router_aux_loss_coef=0.001, + mlp_only_layers=None, + ), + ) + +if QWEN3_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_vl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_vl, + model_class=Qwen3VLForConditionalGeneration, + mini_model_config=Qwen3VLConfig( + tie_word_embeddings=False, + image_token_id=31997, + video_token_id=31998, + vision_start_token_id=31995, + vision_end_token_id=31996, + text_config=dict( + attention_dropout=0.0, + attn_implementation="sdpa", + bos_token_id=1, + eos_token_id=2, + head_dim=112, + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pad_token_id=2, + rms_norm_eps=1e-6, + sliding_window=131072, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + rope_scaling=dict( + type="mrope", + mrope_section=[16, 24, 24], # (temporal, height, width) + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + ), + vision_config=dict( + depth=4, + hidden_size=128, + initializer_range=0.02, + intermediate_size=256, + num_heads=8, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=896, + num_position_embeddings=576, + deepstack_visual_indexes=[1, 2, 3], + ), + ), + ) + +if QWEN3_VL_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_vl_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_vl_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_vl_moe, + model_class=Qwen3VLMoeForConditionalGeneration, + mini_model_config=Qwen3VLMoeConfig( + tie_word_embeddings=False, + image_token_id=31997, + video_token_id=31998, + vision_start_token_id=31995, + vision_end_token_id=31996, + text_config=Qwen3VLMoeTextConfig( + attention_dropout=0.0, + attention_bias=False, + attn_implementation="sdpa", + bos_token_id=1, + eos_token_id=2, + head_dim=112, + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pad_token_id=2, + rms_norm_eps=1e-6, + sliding_window=131072, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + 
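# MoE settings below: decoder_sparse_step=1 and mlp_only_layers=[] make every decoder layer sparse, with 4 experts and top-2 routing +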
decoder_sparse_step=1, + moe_intermediate_size=3072, + num_experts_per_tok=2, + num_experts=4, + mlp_only_layers=[], + ).to_dict(), + vision_config=Qwen3VLMoeVisionConfig( + depth=4, + hidden_size=128, + initializer_range=0.02, + intermediate_size=256, + num_heads=8, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=896, + num_position_embeddings=576, + deepstack_visual_indexes=[1, 2, 3], + ).to_dict(), + ), + ) + +if GEMMA3_AVAILABLE: + MINI_MODEL_SETUPS["mini_gemma3_text"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma3_text, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma3_text, + model_class=Gemma3ForCausalLM, + mini_model_config=Gemma3TextConfig( + vocab_size=32000, # 262144 + hidden_size=1024, # 1152 + intermediate_size=2048, # 6912 + num_hidden_layers=4, # 26 + num_attention_heads=4, + num_key_value_heads=1, + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, # 32768 + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + bos_token_id=2, + eos_token_id=1, + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + attn_implementation="eager", + ), + ) + +if MLLAMA_AVAILABLE: + MINI_MODEL_SETUPS["mini_mllama"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mllama, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mllama, + model_class=MllamaForCausalLM, + mini_model_config=MllamaTextConfig( + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=131_072, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + attn_implementation="sdpa", # default value, pytorch native attention + rope_scaling=dict( + factor=8.0, + high_freq_factor=4.0, + low_freq_factor=1.0, + original_max_position_embeddings=8192, + rope_type="llama3", + rope_theta=500_000, + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + ), + ) + +if QWEN2_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2_vl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_vl, + model_class=Qwen2VLForConditionalGeneration, + mini_model_config=Qwen2VLConfig( + attention_dropout=0.0, + # bos and eos set to match the Mistral-7B tokenizer used to create the test dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 151643 + eos_token_id=2, # 151645 + vision_start_token_id=32765, # vocab_size - 5 + vision_end_token_id=32766, # vocab_size - 4 + vision_token_id=32767, # vocab_size - 3 + image_token_id=32768, # vocab_size - 2 + video_token_id=32769, # vocab_size - 1 + hidden_act="silu", + hidden_size=1536, # 8192 + initializer_range=0.02, + intermediate_size=4864, # 29568 + max_position_embeddings=32768, + max_window_layers=4, # 80 + num_attention_heads=12, # 64 + num_hidden_layers=4, # 80 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-6, # 1e-5 + **( + dict(rope_parameters=dict(mrope_section=[16, 24, 24])) # (temporal, height, width) + if IS_TRANSFORMERS_V5_OR_LATER + else dict(rope_scaling=dict(type="mrope", mrope_section=[16, 24, 24])) + ), + sliding_window=4096, + tie_word_embeddings=False, + 
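# Sanity check on mrope_section above: 16 + 24 + 24 = 64 = head_dim // 2 (head_dim = hidden_size / num_attention_heads = 1536 / 12 = 128) +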
use_cache=True, + vocab_size=32768, # 152064 # >32k, Mistral-7B tokenizer vocab size + use_sliding_window=False, + vision_config={ + "depth": 4, # 32 + "embed_dim": 1280, + "mlp_ratio": 4, + "num_heads": 16, + "in_chans": 3, + "hidden_size": 128, # 1536 + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "temporal_patch_size": 2, + }, + attn_implementation="sdpa", + ), + ) + +if QWEN2_5_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_5_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2_5_vl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_5_vl, + model_class=Qwen2_5_VLForConditionalGeneration, + mini_model_config=Qwen2_5_VLConfig( + attention_dropout=0.0, + # bos and eos set to match the Mistral-7B tokenizer used to create the test dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 151643 + eos_token_id=2, # 151645 + vision_start_token_id=32765, # vocab_size - 5 + vision_end_token_id=32766, # vocab_size - 4 + vision_token_id=32767, # vocab_size - 3 + image_token_id=32768, # vocab_size - 2 + video_token_id=32769, # vocab_size - 1 + hidden_act="silu", + hidden_size=1536, # 8192 + initializer_range=0.02, + intermediate_size=4864, # 29568 + max_position_embeddings=32768, + max_window_layers=4, # 80 + num_attention_heads=12, # 64 + num_hidden_layers=4, # 80 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-6, # 1e-5 + **( + dict(rope_parameters=dict(mrope_section=[16, 24, 24])) # (temporal, height, width) + if IS_TRANSFORMERS_V5_OR_LATER + else dict(rope_scaling=dict(type="mrope", mrope_section=[16, 24, 24])) + ), + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32768, # 152064 # >32k, Mistral-7B tokenizer vocab size + use_sliding_window=False, + vision_config={ + "depth": 4, # 32 + "hidden_act": "silu", + "hidden_size": 128, # 1280 + "intermediate_size": 256, # 3420 + "num_heads": 16, + "in_chans": 3, + "out_hidden_size": 128, # 3584 + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "window_size": 112, + "fullatt_block_indexes": [7, 15, 23, 31], + "tokens_per_second": 2, + "temporal_patch_size": 2, + }, + attn_implementation="sdpa", + ), + ) + +if GRANITE_AVAILABLE: + MINI_MODEL_SETUPS["mini_granite3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_granite, + liger_kernel_patch_revert_func=revert_liger_kernel_to_granite, + model_class=GraniteForCausalLM, + mini_model_config=GraniteConfig( + attention_bias=False, + attention_dropout=0.0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=8192, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + pretraining_tp=1, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + logits_scaling=8.0, + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if LLAVA_AVAILABLE: + # https://huggingface.co/llava-hf/llava-1.5-7b-hf + MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig( + 
liger_kernel_patch_func=apply_liger_kernel_to_llava, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llava, + model_class=LlavaForConditionalGeneration, + mini_model_config=LlavaConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pretraining_tp=1, + tie_word_embeddings=False, + use_cache=True, + max_position_embeddings=4096, # llava-1.5-7b-hf + rms_norm_eps=1e-05, # llava-1.5-7b-hf + vocab_size=32064, # llava-1.5-7b-hf + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + vision_config=CLIPVisionConfig( + hidden_size=1024, + image_size=336, + intermediate_size=2048, # 4096 + model_type="clip_vision_model", + num_attention_heads=4, # 16 + num_hidden_layers=4, # 24 + patch_size=14, + projection_dim=768, + vocab_size=32000, + ), + vocab_size=32064, + ignore_index=-100, + pad_token_id=4, + image_token_index=3, + projector_hidden_act="gelu", + vision_feature_layer=-2, + vision_feature_select_strategy="default", + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if OLMO2_AVAILABLE: + MINI_MODEL_SETUPS["mini_olmo2"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_olmo2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_olmo2, + model_class=Olmo2ForCausalLM, + mini_model_config=Olmo2Config( + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if OLMO3_AVAILABLE: + MINI_MODEL_SETUPS["mini_olmo3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_olmo3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_olmo3, + model_class=Olmo3ForCausalLM, + mini_model_config=Olmo3Config( + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if GLM4_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4, + model_class=Glm4ForCausalLM, + mini_model_config=Glm4Config( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + partial_rotary_factor=0.5, + 
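# partial_rotary_factor=0.5: rotary position embeddings cover only half of each attention head's dimensions +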
cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) +if GLM4V_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4v"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4v, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4v, + model_class=Glm4vForConditionalGeneration, + mini_model_config=Glm4vConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + image_token_id=151343, + video_token_id=151344, + image_start_token_id=151339, + image_end_token_id=151340, + video_start_token_id=151341, + video_end_token_id=151342, + partial_rotary_factor=0.5, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + text_config={ + "partial_rotary_factor": 0.5, + "hidden_act": "silu", + "hidden_size": 1024, + "intermediate_size": 2048, + "max_position_embeddings": 4096, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-5, + "vocab_size": 32000, + "attention_bias": True, + **( + {"rope_scaling": {"type": "default", "mrope_section": [8, 12, 12]}} + if not IS_TRANSFORMERS_V5_OR_LATER + else {} + ), + "pad_token_id": None, + }, + vision_config={ + "depth": 4, # 32 + "hidden_act": "silu", + "hidden_size": 128, # 1280 + "intermediate_size": 256, # 3420 + "num_heads": 16, + "in_chans": 3, + "out_hidden_size": 128, # 3584 + "patch_size": 14, + "spatial_merge_size": 2, + "temporal_patch_size": 2, + }, + ), + ) + +if GLM4V_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4v_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4v_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4v_moe, + model_class=Glm4vMoeForConditionalGeneration, + mini_model_config=Glm4vMoeConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + image_token_id=151343, + video_token_id=151344, + image_start_token_id=151339, + image_end_token_id=151340, + video_start_token_id=151341, + video_end_token_id=151342, + partial_rotary_factor=0.5, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + text_config={ + "partial_rotary_factor": 0.5, + "hidden_act": "silu", + "hidden_size": 1024, + "intermediate_size": 2048, + "max_position_embeddings": 4096, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 2, 
+ "rms_norm_eps": 1e-5, + "vocab_size": 32000, + "attention_bias": True, + "attention_dropout": 0.0, + "moe_intermediate_size": 1408, + "num_experts_per_tok": 2, + "n_shared_experts": 1, + "n_routed_experts": 8, + "routed_scaling_factor": 1.0, + "n_group": 1, + "topk_group": 1, + "first_k_dense_replace": 1, + "norm_topk_prob": True, + **( + {"rope_scaling": {"type": "default", "mrope_section": [8, 12, 12]}} + if not IS_TRANSFORMERS_V5_OR_LATER + else {} + ), + }, + vision_config={ + "depth": 4, # 32 + "hidden_act": "silu", + "hidden_size": 128, # 1280 + "intermediate_size": 256, # 3420 + "num_heads": 16, + "in_chans": 3, + "out_hidden_size": 128, # 3584 + "patch_size": 14, + "spatial_merge_size": 2, + "temporal_patch_size": 2, + }, + ), + ) + +if SMOLLM3_AVAILABLE: + MINI_MODEL_SETUPS["mini_smollm3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_smollm3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_smollm3, + model_class=SmolLM3ForCausalLM, + mini_model_config=SmolLM3Config( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, # 128000 + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=8192, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + pretraining_tp=1, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if INTERNVL_AVAILABLE: + MINI_MODEL_SETUPS["mini_internvl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_internvl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_internvl, + model_class=InternVLForConditionalGeneration, + mini_model_config=InternVLConfig( + text_config=Qwen2Config( + rms_norm_eps=1e-5, + hidden_size=256, # 1024 + intermediate_size=1024, # 4096 + hidden_act="silu", + num_hidden_layers=4, # 24 + num_attention_heads=4, # 16 + num_key_value_heads=2, # 16 + max_position_embeddings=4096, # 8192 + vocab_size=32000, # 151936 + bos_token_id=1, + eos_token_id=2, + pad_token_id=2, + tie_word_embeddings=False, + ), + vision_config={ + "hidden_size": 256, # 1024 + "intermediate_size": 1024, # 4096 + "num_hidden_layers": 4, # 24 + "num_attention_heads": 4, # 16 + }, + image_token_id=10, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if FALCONH1_AVAILABLE: + MINI_MODEL_SETUPS["mini_falcon_h1"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_falcon_h1, + liger_kernel_patch_revert_func=revert_liger_kernel_to_falcon_h1, + model_class=FalconH1ForCausalLM, + mini_model_config=FalconH1Config( + model_type="falcon_h1", + vocab_size=32000, + hidden_size=256, # 4096 + num_hidden_layers=4, # 24 + num_attention_heads=4, # 32 + num_key_value_heads=2, # 8 + intermediate_size=1024, # 11008 + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + mamba_d_ssm=128, # 1024 + mamba_n_heads=16, # 128 + mamba_d_state=32, # 245 + mamba_d_conv=2, # 4 + ), + ) + +if QWEN3NEXT_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_next"] = MiniModelConfig( + 
liger_kernel_patch_func=apply_liger_kernel_to_qwen3_next, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_next, + model_class=Qwen3NextForCausalLM, + mini_model_config=Qwen3NextConfig( # Copied from Qwen3MoeConfig + vocab_size=32000, + hidden_size=896, + intermediate_size=4864, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + attention_bias=False, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + decoder_sparse_step=1, + moe_intermediate_size=768, + num_experts_per_tok=2, + num_experts=8, + norm_topk_prob=False, + output_router_logits=False, + router_aux_loss_coef=0.001, + # config.dtype must be set if fla is installed, since there's a bug in the original code (there is no torch.get_current_dtype()) + # https://github.com/huggingface/transformers/blob/v4.57.1/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L613 + dtype=torch.bfloat16, + ), + ) + +if QWEN3_5_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_5"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_5, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_5, + model_class=Qwen3_5ForCausalLM, + mini_model_config=Qwen3_5TextConfig( + vocab_size=32000, + hidden_size=896, + intermediate_size=4864, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + attention_bias=False, + attention_dropout=0.0, + head_dim=128, + linear_conv_kernel_dim=4, + linear_key_head_dim=64, + linear_value_head_dim=64, + linear_num_key_heads=8, + linear_num_value_heads=8, + layer_types=["linear_attention", "linear_attention", "linear_attention", "full_attention"], + dtype=torch.bfloat16, + ), + ) + + +if HUNYUAN_V1_AVAILABLE: + MINI_MODEL_SETUPS["mini_hunyuan_v1"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_hunyuan_v1_dense, + liger_kernel_patch_revert_func=revert_liger_kernel_to_hunyuan_v1, + model_class=HunYuanDenseV1ForCausalLM, + mini_model_config=HunYuanDenseV1Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + num_hidden_layers=4, + hidden_size=896, + intermediate_size=4864, + num_attention_heads=8, + head_dim=112, + rms_norm_eps=1e-6, + tie_word_embeddings=True, + max_position_embeddings=32768, + initializer_range=0.02, + norm_eps=1e-6, + num_key_value_heads=2, + partial_rotary_factor=1.0, + vocab_size=32000, + use_cache=True, + attn_implementation="sdpa", + ), + ) + + MINI_MODEL_SETUPS["mini_hunyuan_v1_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_hunyuan_v1_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_hunyuan_v1_moe, + model_class=HunYuanMoEV1ForCausalLM, + mini_model_config=HunYuanMoEV1Config( + vocab_size=32000, + hidden_size=128, + intermediate_size=512, + head_dim=16, + num_hidden_layers=2, + num_attention_heads=8, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + eod_token_id=3, + sep_token_id=4, + tie_word_embeddings=False, + attention_bias=False, + attention_dropout=0.0, + num_experts=2, + moe_topk=1, + attn_implementation="sdpa", + ), + ) + +if EXAONE4_AVAILABLE: +
MINI_MODEL_SETUPS["mini_exaone4"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_exaone4, + liger_kernel_patch_revert_func=revert_liger_kernel_to_exaone4, + model_class=Exaone4ForCausalLM, + mini_model_config=Exaone4Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=True, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + pad_token_id=None, + ), + ) + + +def create_model(model_name="mini_llama3"): + """ + Create a mini version model + The commented values are the original values + """ + model_config = MINI_MODEL_SETUPS[model_name].mini_model_config + model_class = MINI_MODEL_SETUPS[model_name].model_class + return model_class(model_config) + + +@require_deterministic +def run_mini_model( + model_name="mini_llama3", + num_steps=100, + dtype=torch.bfloat16, + lr=1e-5, + with_liger=False, +): + # If we move it to the beginning of test_mini_model, the two runs are initialized with different weights. + # This is due to RNG (Random Number Generator). The formula of RNG progression is x_(n+1) = (a * x_n + c) % m + # Everytime RNG is used, like randomly initialzing weight, the RNG progresses to the next state. + # Therefore, we have to reset RNG before we create the model to ensure the weight initialization started from the same RNG state. + + set_seed(42) + + revert_kwargs = {"model_config": MINI_MODEL_SETUPS[model_name]} + if "mllama" in model_name: + revert_kwargs["model_type"] = "causal_lm" + + if with_liger is True: + kwargs = { + "rope": True, + "rms_norm": True, + } + + if "glm4" in model_name or "llama4" in model_name or "qwen3_next" in model_name or "qwen3_5" in model_name: + kwargs["rope"] = False + + model_supports_layer_norm = "qwen2_vl" in model_name + if model_supports_layer_norm: + kwargs["layer_norm"] = True + + if "gemma" in model_name: + kwargs["geglu"] = True + else: + kwargs["swiglu"] = True + + if "llava" in model_name: + apply_liger_kernel_to_llama(**kwargs) + + kwargs["fused_linear_cross_entropy"] = False + kwargs["cross_entropy"] = False + + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) + else: + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) + + model = create_model(model_name).to(dtype).to(device) + + train_dataset = load_from_disk(DEFAULT_DATASET_PATH) + loader = DataLoader(train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn) + loader_iter = iter(loader) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr) + + loss_list = [] + + for i in range(num_steps): + batch = next(loader_iter).to(model.device) + optimizer.zero_grad() + output = model(**batch) + output.loss.backward() + optimizer.step() + print(f"Step {i}, Loss: {output.loss.item()}") + loss_list.append(output.loss.item()) + + topk_logprobs = get_topk(get_logprobs(output.logits)) + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) + return { + "loss": loss_list, + "topk_logprobs": topk_logprobs.values, + "model": model, + } + + +@pytest.mark.parametrize( + "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol", + [ + # Tolerance is set higher than usual to pass the tests. 
pytest.param( + "mini_llama4", # llama4 requires slightly larger tolerances to pass this test after a bug fix to llama4 in transformers v5.0.0 + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 4e-1, + 3e-1, + 2e-1, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not LLAMA4_AVAILABLE, + reason="Llama4 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_llama3", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), + pytest.param( + "mini_llava", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_granite3", + 32, + 1e-5, + torch.bfloat16, + 1e-2, # loss atol + 1e-2, # loss rtol + 1e-1, # logprobs atol + 1e-2, # logprobs rtol + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GRANITE_AVAILABLE, + reason="Granite not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_mllama", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not MLLAMA_AVAILABLE, + reason="Mllama not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen2", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), + pytest.param( + "mini_qwen3", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3_AVAILABLE, + reason="Qwen3 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen3_moe", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 2e-1, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3_AVAILABLE, + reason="Qwen3 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen3_vl", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3_VL_AVAILABLE, + reason="Qwen3-VL not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen3_vl_moe", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3_VL_MOE_AVAILABLE, + reason="Qwen3-VL-MoE not available in this version of transformers", + ), + pytest.mark.skipif(True, reason="Flaky test"), + ], + ), + pytest.param( + "mini_qwen2_vl", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not
supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen2_5_vl", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN2_5_VL_AVAILABLE, + reason="Qwen2.5-VL not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_phi3", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), + pytest.param( + "mini_mistral", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), + # TODO: mixtral is flaky, so the test is disabled for now + # pytest.param( + # "mini_mixtral", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-1, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + # Gemma 1.1 and 2 have higher tolerance because the kernel is currently not a perfect match + pytest.param( + "mini_gemma1", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), + pytest.param( + "mini_gemma1.1", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), + pytest.param( + "mini_olmo2", + 32, + 1e-4, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not OLMO2_AVAILABLE, + reason="OLMO2 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_olmo3", + 32, + 1e-4, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not OLMO3_AVAILABLE, + reason="OLMO3 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_glm4", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GLM4_AVAILABLE, + reason="Glm4 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_glm4v", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 2e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GLM4V_AVAILABLE, + reason="Glm4v not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_glm4v_moe", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 4e-1, # rms_norm patch needs higher tolerance in bf16 + 1e-1, + 5e-1, # rms_norm patch needs higher tolerance in bf16 + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GLM4V_MOE_AVAILABLE, + reason="Glm4v_moe not
available in this version of transformers", + ), + ], + ), + # TODO: Gemma2 test for bf16 is not passing within the tolerance range, might be casting issue, need to investigate + # pytest.param( + # "mini_gemma2", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + pytest.param( + "mini_gemma3_text", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 3e-1, # 1e-1 too flaky + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GEMMA3_AVAILABLE, + reason="Gemma3 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_smollm3", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not SMOLLM3_AVAILABLE, + reason="Smollm3 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_internvl", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not INTERNVL_AVAILABLE, + reason="InternVL not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_falcon_h1", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not FALCONH1_AVAILABLE, + reason="FalconH1 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen3_next", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-1, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3NEXT_AVAILABLE, + reason="Qwen3Next not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen3_5", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 2e-1, + 1e-1, + 1e-1, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not QWEN3_5_AVAILABLE, + reason="Qwen3_5 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_hunyuan_v1", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not HUNYUAN_V1_AVAILABLE, + reason="Hunyuan_v1 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_hunyuan_v1_moe", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not HUNYUAN_V1_AVAILABLE, + reason="Hunyuan_v1_moe not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_exaone4", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not EXAONE4_AVAILABLE, + reason="EXAONE4 not available in this version of 
transformers", + ), + ], + ), + ], +) +def test_mini_model( + model_name, + num_steps, + lr, + dtype, + loss_atol, + loss_rtol, + logprobs_atol, + logprobs_rtol, + param_atol, + param_rtol, +): + # Non-liger models should be initialized and tested first to avoid the module being overridden + + expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr) + + actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True) + + # Compare every step of the loss + assert_verbose_allclose( + torch.tensor([expected_output["loss"]]), + torch.tensor([actual_output["loss"]]), + atol=loss_atol, + rtol=loss_rtol, + extra_info="[Loss]", + ) + + # Compare the topk logprobs from evaluation step + assert_verbose_allclose( + expected_output["topk_logprobs"], + actual_output["topk_logprobs"], + atol=logprobs_atol, + rtol=logprobs_rtol, + extra_info="[Top k logprobs]", + ) + + # Compare the params from the last step + # Iterate over the model's parameters and compare them + for expected_param, actual_param in zip( + expected_output["model"].named_parameters(), + actual_output["model"].named_parameters(), + ): + assert_verbose_allclose( + expected_param[1], + actual_param[1], + atol=param_atol, + rtol=param_rtol, + extra_info="[Model parameters]", + ) diff --git a/test/convergence/fp32/__init__.py b/test/convergence/fp32/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/convergence/fp32/test_mini_models.py b/test/convergence/fp32/test_mini_models.py new file mode 100755 index 0000000000000000000000000000000000000000..1a09cba21352152470c64d8237e6007d4e8c64b9 --- /dev/null +++ b/test/convergence/fp32/test_mini_models.py @@ -0,0 +1,2170 @@ +import os + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Ensure deterministic behavior with CuBLAS + +import pytest +import torch +import transformers + +from datasets import load_from_disk +from packaging import version +from torch.utils.data import DataLoader +from transformers.models.gemma import GemmaConfig +from transformers.models.gemma import GemmaForCausalLM +from transformers.models.gemma2 import Gemma2Config +from transformers.models.gemma2 import Gemma2ForCausalLM +from transformers.models.llama import LlamaConfig +from transformers.models.llama import LlamaForCausalLM +from transformers.models.mistral import MistralConfig +from transformers.models.mistral import MistralForCausalLM +from transformers.models.mixtral import MixtralConfig +from transformers.models.mixtral import MixtralForCausalLM +from transformers.models.phi3 import Phi3Config +from transformers.models.phi3 import Phi3ForCausalLM +from transformers.models.qwen2 import Qwen2Config +from transformers.models.qwen2 import Qwen2ForCausalLM + +from liger_kernel.transformers import apply_liger_kernel_to_exaone4 +from liger_kernel.transformers import apply_liger_kernel_to_falcon_h1 +from liger_kernel.transformers import apply_liger_kernel_to_gemma +from liger_kernel.transformers import apply_liger_kernel_to_gemma2 +from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text +from liger_kernel.transformers import apply_liger_kernel_to_glm4 +from liger_kernel.transformers import apply_liger_kernel_to_glm4v +from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe +from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss +from liger_kernel.transformers import apply_liger_kernel_to_granite +from 
liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_dense +from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_moe +from liger_kernel.transformers import apply_liger_kernel_to_internvl +from liger_kernel.transformers import apply_liger_kernel_to_llama +from liger_kernel.transformers import apply_liger_kernel_to_llama4 +from liger_kernel.transformers import apply_liger_kernel_to_llava +from liger_kernel.transformers import apply_liger_kernel_to_mistral +from liger_kernel.transformers import apply_liger_kernel_to_mixtral +from liger_kernel.transformers import apply_liger_kernel_to_mllama +from liger_kernel.transformers import apply_liger_kernel_to_olmo2 +from liger_kernel.transformers import apply_liger_kernel_to_olmo3 +from liger_kernel.transformers import apply_liger_kernel_to_phi3 +from liger_kernel.transformers import apply_liger_kernel_to_qwen2 +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen3 +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5_moe +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl_moe +from liger_kernel.transformers import apply_liger_kernel_to_smollm3 +from liger_kernel.utils import infer_device +from test.utils import DEFAULT_DATASET_PATH +from test.utils import MiniModelConfig +from test.utils import assert_verbose_allclose +from test.utils import get_logprobs +from test.utils import get_topk +from test.utils import require_deterministic +from test.utils import revert_liger_kernel_to_exaone4 +from test.utils import revert_liger_kernel_to_falcon_h1 +from test.utils import revert_liger_kernel_to_gemma +from test.utils import revert_liger_kernel_to_gemma2 +from test.utils import revert_liger_kernel_to_gemma3_text +from test.utils import revert_liger_kernel_to_glm4 +from test.utils import revert_liger_kernel_to_glm4v +from test.utils import revert_liger_kernel_to_glm4v_moe +from test.utils import revert_liger_kernel_to_gpt_oss +from test.utils import revert_liger_kernel_to_granite +from test.utils import revert_liger_kernel_to_hunyuan_v1 +from test.utils import revert_liger_kernel_to_hunyuan_v1_moe +from test.utils import revert_liger_kernel_to_internvl +from test.utils import revert_liger_kernel_to_llama +from test.utils import revert_liger_kernel_to_llama4 +from test.utils import revert_liger_kernel_to_llava +from test.utils import revert_liger_kernel_to_mistral +from test.utils import revert_liger_kernel_to_mixtral +from test.utils import revert_liger_kernel_to_mllama +from test.utils import revert_liger_kernel_to_olmo2 +from test.utils import revert_liger_kernel_to_olmo3 +from test.utils import revert_liger_kernel_to_phi3 +from test.utils import revert_liger_kernel_to_qwen2 +from test.utils import revert_liger_kernel_to_qwen2_5_vl +from test.utils import revert_liger_kernel_to_qwen2_vl +from test.utils import revert_liger_kernel_to_qwen3 +from test.utils import revert_liger_kernel_to_qwen3_5 +from test.utils import revert_liger_kernel_to_qwen3_5_moe +from test.utils import revert_liger_kernel_to_qwen3_moe +from test.utils import revert_liger_kernel_to_qwen3_next 
+from test.utils import revert_liger_kernel_to_qwen3_vl +from test.utils import revert_liger_kernel_to_qwen3_vl_moe +from test.utils import revert_liger_kernel_to_smollm3 +from test.utils import set_seed +from test.utils import simple_collate_fn + +IS_TRANSFORMERS_V5_OR_LATER = version.parse(transformers.__version__) >= version.parse("5.0.0") + +try: + from transformers.models.llama4.configuration_llama4 import Llama4TextConfig + from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM + + LLAMA4_AVAILABLE = True +except ImportError: + LLAMA4_AVAILABLE = False + +try: + # Mllama is only available in transformers>=4.45.0 + from transformers.models.mllama.configuration_mllama import MllamaTextConfig + from transformers.models.mllama.modeling_mllama import MllamaForCausalLM + + MLLAMA_AVAILABLE = True +except ImportError: + MLLAMA_AVAILABLE = False + +try: + # Qwen2-VL is only available in transformers>=4.52.4 + import transformers + + from packaging import version + from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration + + QWEN2_VL_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.52.4") +except ImportError: + QWEN2_VL_AVAILABLE = False + +try: + # Qwen2.5-VL is only available in transformers>=4.52.4 + import transformers + + from packaging import version + from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration + + QWEN2_5_VL_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.52.4") +except ImportError: + QWEN2_5_VL_AVAILABLE = False + + +try: + # Qwen3-VL is only available in transformers>=4.57.0 + import transformers + + from packaging import version + from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration + + QWEN3_VL_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.57.0") +except ImportError: + QWEN3_VL_AVAILABLE = False + + +try: + # Qwen3-VL-MoE is only available in transformers>=4.57.0 + import transformers + + from packaging import version + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeTextConfig + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeVisionConfig + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration + + QWEN3_VL_MOE_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.57.0") +except ImportError: + QWEN3_VL_MOE_AVAILABLE = False + +try: + from transformers.models.granite import GraniteConfig + from transformers.models.granite import GraniteForCausalLM + + GRANITE_AVAILABLE = True +except ImportError: + GRANITE_AVAILABLE = False + +try: + from transformers import CLIPVisionConfig + from transformers.models.llava.configuration_llava import LlavaConfig + from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration + + LLAVA_AVAILABLE = True +except ImportError: + LLAVA_AVAILABLE = False + +try: + # OLMO2 is only available in transformers>=4.47.0 + from transformers.models.olmo2.configuration_olmo2 import Olmo2Config + from transformers.models.olmo2.modeling_olmo2 import Olmo2ForCausalLM + +
OLMO2_AVAILABLE = True +except ImportError: + OLMO2_AVAILABLE = False + +try: + # OLMO3 is only available in transformers>=4.57.0 + from transformers.models.olmo3.configuration_olmo3 import Olmo3Config + from transformers.models.olmo3.modeling_olmo3 import Olmo3ForCausalLM + + OLMO3_AVAILABLE = True +except ImportError: + OLMO3_AVAILABLE = False + +try: + # Glm4 is only available in transformers>=4.51.3 + from transformers.models.glm4.configuration_glm4 import Glm4Config + from transformers.models.glm4.modeling_glm4 import Glm4ForCausalLM + + GLM4_AVAILABLE = True +except ImportError: + GLM4_AVAILABLE = False + +try: + # Glm4v is only available in transformers>=4.51.3 + from transformers.models.glm4v.configuration_glm4v import Glm4vConfig + from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration + + GLM4V_AVAILABLE = True +except ImportError: + GLM4V_AVAILABLE = False + +try: + # Glm4v_moe is only available in transformers>=4.51.3 + from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig + from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration + + GLM4V_MOE_AVAILABLE = True +except ImportError: + GLM4V_MOE_AVAILABLE = False + +try: + from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig + from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM + + GEMMA3_AVAILABLE = True +except ImportError: + GEMMA3_AVAILABLE = False + +try: + # Smollm3 is only available in transformers>=4.53.0 + from transformers.models.smollm3.configuration_smollm3 import SmolLM3Config + from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM + + SMOLLM3_AVAILABLE = True +except ImportError: + SMOLLM3_AVAILABLE = False + +try: + from transformers.models.qwen3.configuration_qwen3 import Qwen3Config + from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM + from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig + from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeForCausalLM + + QWEN3_AVAILABLE = True +except ImportError: + QWEN3_AVAILABLE = False + +try: + # GPT-OSS is only available in transformers>=4.55.0 + from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig + from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM + + GPT_OSS_AVAILABLE = True +except ImportError: + GPT_OSS_AVAILABLE = False + +try: + # InternVL is only available in transformers>=4.52.1 + from transformers.models.internvl.configuration_internvl import InternVLConfig + from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration + + INTERNVL_AVAILABLE = True +except ImportError: + INTERNVL_AVAILABLE = False + +try: + # FalconH1 is only available in transformers>=4.53.0 + from transformers.models.falcon_h1.configuration_falcon_h1 import FalconH1Config + from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1ForCausalLM + + FALCONH1_AVAILABLE = True +except ImportError: + FALCONH1_AVAILABLE = False + +try: + # Qwen3Next is only available in transformers>=4.57.0 + from transformers.models.qwen3_next.configuration_qwen3_next import Qwen3NextConfig + from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM + + QWEN3NEXT_AVAILABLE = True +except ImportError: + QWEN3NEXT_AVAILABLE = False + +try: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe 
import Qwen3_5MoeTextConfig + + QWEN3_5_MOE_AVAILABLE = True +except ImportError: + QWEN3_5_MOE_AVAILABLE = False + +try: + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM + + QWEN3_5_AVAILABLE = True +except ImportError: + QWEN3_5_AVAILABLE = False + +try: + from transformers.models.hunyuan_v1_dense.configuration_hunyuan_v1_dense import HunYuanDenseV1Config + from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1ForCausalLM + from transformers.models.hunyuan_v1_moe.configuration_hunyuan_v1_moe import HunYuanMoEV1Config + from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1ForCausalLM + + HUNYUAN_V1_AVAILABLE = True +except ImportError: + HUNYUAN_V1_AVAILABLE = False + +try: + from transformers.models.exaone4.configuration_exaone4 import Exaone4Config + from transformers.models.exaone4.modeling_exaone4 import Exaone4ForCausalLM + + EXAONE4_AVAILABLE = True +except ImportError: + EXAONE4_AVAILABLE = False + + +device = infer_device() + +MINI_MODEL_SETUPS = { + "mini_llama3": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llama, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llama, + model_class=LlamaForCausalLM, + mini_model_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=8192, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + pretraining_tp=1, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ), + "mini_qwen2": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2, + model_class=Qwen2ForCausalLM, + mini_model_config=Qwen2Config( + attention_dropout=0.0, + bos_token_id=1, # 151643 + eos_token_id=2, # 151643 + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, # 131072 + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-6, + sliding_window=131072, + tie_word_embeddings=True, + use_cache=True, + vocab_size=32000, # 151936 + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ), + "mini_phi3": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_phi3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_phi3, + model_class=Phi3ForCausalLM, + mini_model_config=Phi3Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, # 32000 + hidden_act="silu", + hidden_size=896, # 3072 + initializer_range=0.02, + intermediate_size=4864, # 8192 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + 
num_hidden_layers=4, # 32 + num_key_value_heads=None, # defaults to num_attention_heads + rms_norm_eps=1e-5, + sliding_window=None, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32064, + attn_implementation="eager", + ), + ), + "mini_mistral": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mistral, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mistral, + model_class=MistralForCausalLM, + mini_model_config=MistralConfig( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-5, + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ), + "mini_mixtral": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mixtral, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mixtral, + model_class=MixtralForCausalLM, + mini_model_config=MixtralConfig( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=512, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=32768, # 32768 + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ), + "mini_gemma1": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma, + model_class=GemmaForCausalLM, + mini_model_config=GemmaConfig( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + # gemma1 model config uses `hidden_act` and points it to gelu, + # https://huggingface.co/google/gemma-7b/blob/main/config.json#L10 + # but in reality it is ignored and HuggingFace uses the tanh approximation: + # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175 + hidden_act="gelu", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + ), + ), + "mini_gemma1.1": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma, + model_class=GemmaForCausalLM, + mini_model_config=GemmaConfig( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 +
tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + ), + ), + "mini_gemma2": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma2, + model_class=Gemma2ForCausalLM, + mini_model_config=Gemma2Config( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + attn_implementation="eager", + ), + ), +} +if LLAMA4_AVAILABLE: + MINI_MODEL_SETUPS["mini_llama4"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llama4, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llama4, + model_class=Llama4ForCausalLM, + mini_model_config=Llama4TextConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + partial_rotary_factor=1.0, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + + +if QWEN3_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3, + model_class=Qwen3ForCausalLM, + mini_model_config=Qwen3Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-6, + sliding_window=131072, + tie_word_embeddings=True, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ) + + MINI_MODEL_SETUPS["mini_qwen3_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_moe, + model_class=Qwen3MoeForCausalLM, + mini_model_config=Qwen3MoeConfig( + vocab_size=32000, # 151936 + hidden_size=896, + intermediate_size=4864, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + attention_bias=False, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + decoder_sparse_step=1, + moe_intermediate_size=768, + num_experts_per_tok=2, + num_experts=8, + norm_topk_prob=False, + output_router_logits=False, + router_aux_loss_coef=0.001, + mlp_only_layers=None, + ), + ) + +if GPT_OSS_AVAILABLE: + MINI_MODEL_SETUPS["mini_gpt_oss"] = MiniModelConfig( + 
liger_kernel_patch_func=apply_liger_kernel_to_gpt_oss, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gpt_oss, + model_class=GptOssForCausalLM, + mini_model_config=GptOssConfig( + vocab_size=32000, # 201088 + hidden_size=896, + intermediate_size=896, # Same as hidden_size for GPT-OSS + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + head_dim=64, + hidden_act="silu", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + attention_dropout=0.0, + num_local_experts=8, # Reduced from 32 for mini model + num_experts_per_tok=2, # Reduced from 4 for mini model + router_aux_loss_coef=0.9, + output_router_logits=False, + sliding_window=128, + layer_types=["sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(4)], + ), + ) + +if GEMMA3_AVAILABLE: + MINI_MODEL_SETUPS["mini_gemma3_text"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma3_text, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma3_text, + model_class=Gemma3ForCausalLM, + mini_model_config=Gemma3TextConfig( + vocab_size=32000, # 262144 + hidden_size=1024, # 1152 + intermediate_size=2048, # 6912 + num_hidden_layers=4, # 26 + num_attention_heads=4, + num_key_value_heads=1, + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, # 32768 + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + bos_token_id=2, + eos_token_id=1, + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + attn_implementation="eager", + ), + ) + +if MLLAMA_AVAILABLE: + MINI_MODEL_SETUPS["mini_mllama"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mllama, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mllama, + model_class=MllamaForCausalLM, + mini_model_config=MllamaTextConfig( + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=131_072, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + attn_implementation="sdpa", # default value, pytorch native attention + rope_scaling=dict( + factor=8.0, + high_freq_factor=4.0, + low_freq_factor=1.0, + original_max_position_embeddings=8192, + rope_type="llama3", + rope_theta=500_000, + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + ), + ) + +if QWEN2_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2_vl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_vl, + model_class=Qwen2VLForConditionalGeneration, + mini_model_config=Qwen2VLConfig( + # In transformers v5, text-related parameters must be in text_config + text_config={ + "attention_dropout": 0.0, + # bos and eos set to match the Mistral-7B tokenizer used to create the test dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + "bos_token_id": 1, # 151643 + "eos_token_id": 2, # 151645 + "hidden_act": "silu", + "hidden_size": 1536, # 8192 + "initializer_range": 0.02, + "intermediate_size": 4864, # 29568 + "max_position_embeddings": 32768, + "max_window_layers": 4, # 80 + "num_attention_heads": 12, # 64 + "num_hidden_layers": 4, # 80 + "num_key_value_heads": 
2, # 8 + "rms_norm_eps": 1e-6, # 1e-5 + **( + {"rope_parameters": {"mrope_section": [16, 24, 24]}} # (temporal, height, width) + if IS_TRANSFORMERS_V5_OR_LATER + else {"rope_scaling": {"type": "mrope", "mrope_section": [16, 24, 24]}} + ), + "sliding_window": 4096, + "tie_word_embeddings": False, + "use_cache": True, + "vocab_size": 32768, # 152064 # >32k, Mistral-7B tokenizer vocab size + "use_sliding_window": False, + }, + vision_start_token_id=32765, # vocab_size - 5 + vision_end_token_id=32766, # vocab_size - 4 + image_token_id=32768, # vocab_size - 2 + video_token_id=32769, # vocab_size - 1 + vision_config={ + "depth": 4, # 32 + "embed_dim": 1280, + "mlp_ratio": 4, + "num_heads": 16, + "in_chans": 3, + "hidden_size": 128, # 1536 + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "temporal_patch_size": 2, + }, + ), + ) + +if QWEN2_5_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_5_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2_5_vl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_5_vl, + model_class=Qwen2_5_VLForConditionalGeneration, + mini_model_config=Qwen2_5_VLConfig( + # In transformers v5, text-related parameters must be in text_config + text_config={ + "attention_dropout": 0.0, + # bos and eos set to match the Mistral-7B tokenizer used to create the test dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + "bos_token_id": 1, # 151643 + "eos_token_id": 2, # 151645 + "hidden_act": "silu", + "hidden_size": 1536, # 8192 + "initializer_range": 0.02, + "intermediate_size": 4864, # 29568 + "max_position_embeddings": 32768, + "max_window_layers": 4, # 80 + "num_attention_heads": 12, # 64 + "num_hidden_layers": 4, # 80 + "num_key_value_heads": 2, # 8 + "rms_norm_eps": 1e-6, # 1e-5 + **( + {"rope_parameters": {"mrope_section": [16, 24, 24]}} # (temporal, height, width) + if IS_TRANSFORMERS_V5_OR_LATER + else {"rope_scaling": {"type": "mrope", "mrope_section": [16, 24, 24]}} + ), + "sliding_window": 4096, + "tie_word_embeddings": False, + "use_cache": True, + "vocab_size": 32768, # 152064 # >32k, Mistral-7B tokenizer vocab size + "use_sliding_window": False, + }, + vision_start_token_id=32765, # vocab_size - 5 + vision_end_token_id=32766, # vocab_size - 4 + image_token_id=32768, # vocab_size - 2 + video_token_id=32769, # vocab_size - 1 + vision_config={ + "depth": 4, # 32 + "hidden_act": "silu", + "hidden_size": 128, # 1280 + "intermediate_size": 256, # 3420 + "num_heads": 16, + "in_chans": 3, + "out_hidden_size": 128, # 3584 + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "window_size": 112, + "fullatt_block_indexes": [7, 15, 23, 31], + "tokens_per_second": 2, + "temporal_patch_size": 2, + }, + ), + ) + +if QWEN3_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_vl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_vl, + model_class=Qwen3VLForConditionalGeneration, + mini_model_config=Qwen3VLConfig( + bos_token_id=1, + eos_token_id=2, + vision_start_token_id=32765, + vision_end_token_id=32766, + image_token_id=32768, + video_token_id=32769, + tie_word_embeddings=False, + attn_implementation="sdpa", + text_config=dict( + attention_dropout=0.0, + hidden_act="silu", + hidden_size=1536, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=12, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-6, + use_cache=True, 
+ vocab_size=32768, + rope_scaling=dict( + type="mrope", + mrope_section=[16, 24, 24], # (temporal, height, width) + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + ), + vision_config=dict( + depth=4, + hidden_size=128, + hidden_act="silu", + intermediate_size=256, + num_heads=8, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=128, + num_position_embeddings=256, + deepstack_visual_indexes=[], + initializer_range=0.02, + ), + ), + ) + +if QWEN3_VL_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_vl_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_vl_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_vl_moe, + model_class=Qwen3VLMoeForConditionalGeneration, + mini_model_config=Qwen3VLMoeConfig( + bos_token_id=1, + eos_token_id=2, + vision_start_token_id=32765, + vision_end_token_id=32766, + image_token_id=32768, + video_token_id=32769, + tie_word_embeddings=False, + attn_implementation="sdpa", + text_config=Qwen3VLMoeTextConfig( + attention_dropout=0.0, + attention_bias=False, + hidden_act="silu", + hidden_size=1536, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=12, + num_hidden_layers=4, + num_key_value_heads=2, + head_dim=128, + rms_norm_eps=1e-6, + use_cache=True, + vocab_size=32768, + decoder_sparse_step=1, + moe_intermediate_size=3072, + num_experts_per_tok=2, + num_experts=4, + tie_word_embeddings=False, + mlp_only_layers=[], + pad_token_id=None, + rope_scaling=dict( + type="mrope", + mrope_section=[16, 24, 24], # (temporal, height, width) + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + ).to_dict(), + vision_config=Qwen3VLMoeVisionConfig( + depth=4, + hidden_size=128, + hidden_act="gelu_pytorch_tanh", + intermediate_size=256, + num_heads=8, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=128, + num_position_embeddings=256, + deepstack_visual_indexes=[1, 2, 3], + initializer_range=0.02, + ).to_dict(), + ), + ) + +if GRANITE_AVAILABLE: + MINI_MODEL_SETUPS["mini_granite3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_granite, + liger_kernel_patch_revert_func=revert_liger_kernel_to_granite, + model_class=GraniteForCausalLM, + mini_model_config=GraniteConfig( + attention_bias=False, + attention_dropout=0.1, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=8192, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + pretraining_tp=1, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if OLMO2_AVAILABLE: + MINI_MODEL_SETUPS["mini_olmo2"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_olmo2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_olmo2, + model_class=Olmo2ForCausalLM, + mini_model_config=Olmo2Config( + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, + 
cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if OLMO3_AVAILABLE: + MINI_MODEL_SETUPS["mini_olmo3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_olmo3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_olmo3, + model_class=Olmo3ForCausalLM, + mini_model_config=Olmo3Config( + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if GLM4_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4, + model_class=Glm4ForCausalLM, + mini_model_config=Glm4Config( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + partial_rotary_factor=0.5, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=32768, + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if GLM4V_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4v"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4v, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4v, + model_class=Glm4vForConditionalGeneration, + mini_model_config=Glm4vConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + image_token_id=151343, + video_token_id=151344, + image_start_token_id=151339, + image_end_token_id=151340, + video_start_token_id=151341, + video_end_token_id=151342, + partial_rotary_factor=0.5, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + text_config={ + "partial_rotary_factor": 0.5, + "hidden_act": "silu", + "hidden_size": 1024, + "intermediate_size": 2048, + "max_position_embeddings": 4096, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-5, + "vocab_size": 32000, + "attention_bias": True, + **( + {"rope_scaling": {"type": "default", "mrope_section": [8, 12, 12]}} + if not IS_TRANSFORMERS_V5_OR_LATER + else {} + ), + 
"pad_token_id": None, + }, + vision_config={ + "depth": 4, # 32 + "hidden_act": "silu", + "hidden_size": 128, # 1280 + "intermediate_size": 256, # 3420 + "num_heads": 16, + "in_chans": 3, + "out_hidden_size": 128, # 3584 + "patch_size": 14, + "spatial_merge_size": 2, + "temporal_patch_size": 2, + }, + ), + ) + +if GLM4V_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4v_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4v_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4v_moe, + model_class=Glm4vMoeForConditionalGeneration, + mini_model_config=Glm4vMoeConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + image_token_id=151343, + video_token_id=151344, + image_start_token_id=151339, + image_end_token_id=151340, + video_start_token_id=151341, + video_end_token_id=151342, + partial_rotary_factor=0.5, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + text_config={ + "partial_rotary_factor": 0.5, + "hidden_act": "silu", + "hidden_size": 1024, + "intermediate_size": 2048, + "max_position_embeddings": 4096, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-5, + "vocab_size": 32000, + "attention_bias": True, + "attention_dropout": 0.0, + "moe_intermediate_size": 1408, + "num_experts_per_tok": 2, + "n_shared_experts": 1, + "n_routed_experts": 8, + "routed_scaling_factor": 1.0, + "n_group": 1, + "topk_group": 1, + "first_k_dense_replace": 1, + "norm_topk_prob": True, + **( + {"rope_scaling": {"type": "default", "mrope_section": [8, 12, 12]}} + if not IS_TRANSFORMERS_V5_OR_LATER + else {} + ), + }, + vision_config={ + "depth": 4, # 32 + "hidden_act": "silu", + "hidden_size": 128, # 1280 + "intermediate_size": 256, # 3420 + "num_heads": 16, + "in_chans": 3, + "out_hidden_size": 128, # 3584 + "patch_size": 14, + "spatial_merge_size": 2, + "temporal_patch_size": 2, + }, + ), + ) +if LLAVA_AVAILABLE: + # https://huggingface.co/llava-hf/llava-1.5-7b-hf + MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llava, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llava, + model_class=LlavaForConditionalGeneration, + mini_model_config=LlavaConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pretraining_tp=1, + tie_word_embeddings=False, + use_cache=True, + max_position_embeddings=4096, # llava-1.5-7b-hf + rms_norm_eps=1e-05, # llava-1.5-7b-hf + vocab_size=32064, # llava-1.5-7b-hf + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + vision_config=CLIPVisionConfig( + hidden_size=1024, + image_size=336, + intermediate_size=2048, # 4096 + model_type="clip_vision_model", + num_attention_heads=4, # 
16 + num_hidden_layers=4, # 24 + patch_size=14, + projection_dim=768, + vocab_size=32000, + ), + vocab_size=32064, + ignore_index=-100, + pad_token_id=4, + image_token_index=3, + projector_hidden_act="gelu", + vision_feature_layer=-2, + vision_feature_select_strategy="default", + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if SMOLLM3_AVAILABLE: + MINI_MODEL_SETUPS["mini_smollm3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_smollm3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_smollm3, + model_class=SmolLM3ForCausalLM, + mini_model_config=SmolLM3Config( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, # 128000 + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=8192, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + pretraining_tp=1, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if INTERNVL_AVAILABLE: + MINI_MODEL_SETUPS["mini_internvl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_internvl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_internvl, + model_class=InternVLForConditionalGeneration, + mini_model_config=InternVLConfig( + text_config=Qwen2Config( + rms_norm_eps=1e-5, + hidden_size=256, # 1024 + intermediate_size=1024, # 4096 + hidden_act="silu", + num_hidden_layers=4, # 24 + num_attention_heads=4, # 16 + num_key_value_heads=2, # 16 + max_position_embeddings=4096, # 8192 + vocab_size=32000, # 151936 + bos_token_id=1, + eos_token_id=2, + pad_token_id=2, + tie_word_embeddings=False, + ), + vision_config={ + "hidden_size": 256, # 1024 + "intermediate_size": 1024, # 4096 + "num_hidden_layers": 4, # 24 + "num_attention_heads": 4, # 16 + }, + image_token_id=10, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if FALCONH1_AVAILABLE: + MINI_MODEL_SETUPS["mini_falcon_h1"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_falcon_h1, + liger_kernel_patch_revert_func=revert_liger_kernel_to_falcon_h1, + model_class=FalconH1ForCausalLM, + mini_model_config=FalconH1Config( + model_type="falcon_h1", + vocab_size=32000, + hidden_size=256, # 4096 + num_hidden_layers=4, # 24 + num_attention_heads=4, # 32 + num_key_value_heads=2, # 8 + intermediate_size=1024, # 11008 + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + mamba_d_ssm=128, # 1024 + mamba_n_heads=16, # 128 + mamba_d_state=32, # 245 + mamba_d_conv=2, # 4 + ), + ) + +if QWEN3NEXT_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_next"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_next, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_next, + model_class=Qwen3NextForCausalLM, + mini_model_config=Qwen3NextConfig( # Copypaste Qwen3MoeConfig + vocab_size=32000, + 
hidden_size=896, + intermediate_size=4864, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + attention_bias=False, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + decoder_sparse_step=1, + moe_intermediate_size=768, + num_experts_per_tok=2, + num_experts=8, + norm_topk_prob=False, + output_router_logits=False, + router_aux_loss_coef=0.001, + # config.dtype must be set if fla (flash-linear-attention) is installed, since the original code references the non-existent torch.get_current_dtype() + dtype=torch.float32, + ), + ) + +if QWEN3_5_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_5_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_5_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_5_moe, + model_class=Qwen3_5MoeForCausalLM, + mini_model_config=Qwen3_5MoeTextConfig( + vocab_size=32000, + hidden_size=896, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + attention_bias=False, + attention_dropout=0.0, + head_dim=128, + linear_conv_kernel_dim=4, + linear_key_head_dim=64, + linear_value_head_dim=64, + linear_num_key_heads=8, + linear_num_value_heads=8, + moe_intermediate_size=768, + shared_expert_intermediate_size=768, + num_experts_per_tok=2, + num_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + # config.dtype must be set if fla (flash-linear-attention) is installed, since the original code references the non-existent torch.get_current_dtype() + dtype=torch.float32, + ), + ) + +if QWEN3_5_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_5"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_5, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_5, + model_class=Qwen3_5ForCausalLM, + mini_model_config=Qwen3_5TextConfig( + vocab_size=32000, + hidden_size=896, + intermediate_size=4864, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + attention_bias=False, + attention_dropout=0.0, + head_dim=128, + linear_conv_kernel_dim=4, + linear_key_head_dim=64, + linear_value_head_dim=64, + linear_num_key_heads=8, + linear_num_value_heads=8, + layer_types=["linear_attention", "linear_attention", "linear_attention", "full_attention"], + dtype=torch.float32, + ), + ) + +if HUNYUAN_V1_AVAILABLE: + MINI_MODEL_SETUPS["mini_hunyuan_v1"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_hunyuan_v1_dense, + liger_kernel_patch_revert_func=revert_liger_kernel_to_hunyuan_v1, + model_class=HunYuanDenseV1ForCausalLM, + mini_model_config=HunYuanDenseV1Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + num_hidden_layers=4, + hidden_size=896, + intermediate_size=4864, + num_attention_heads=8, + head_dim=112, + rms_norm_eps=1e-6, + tie_word_embeddings=True, + max_position_embeddings=32768, + initializer_range=0.02, + norm_eps=1e-6, + num_key_value_heads=2, + partial_rotary_factor=1.0, + vocab_size=32000, + use_cache=True, + attn_implementation="sdpa", + ), + ) + + MINI_MODEL_SETUPS["mini_hunyuan_v1_moe"] = MiniModelConfig( +
liger_kernel_patch_func=apply_liger_kernel_to_hunyuan_v1_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_hunyuan_v1_moe, + model_class=HunYuanMoEV1ForCausalLM, + mini_model_config=HunYuanMoEV1Config( + hidden_act="silu", + attention_dropout=0.0, + num_hidden_layers=4, + hidden_size=896, + intermediate_size=4864, + num_attention_heads=8, + head_dim=112, + rms_norm_eps=1e-6, + tie_word_embeddings=True, + max_position_embeddings=32768, + initializer_range=0.02, + norm_eps=1e-6, + num_key_value_heads=2, + partial_rotary_factor=1.0, + vocab_size=32000, + num_experts=8, + moe_topk=2, + use_cache=True, + attn_implementation="sdpa", + ), + ) + +if EXAONE4_AVAILABLE: + MINI_MODEL_SETUPS["mini_exaone4"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_exaone4, + liger_kernel_patch_revert_func=revert_liger_kernel_to_exaone4, + model_class=Exaone4ForCausalLM, + mini_model_config=Exaone4Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=True, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + pad_token_id=None, + ), + ) + + +def create_model(model_name="mini_llama3"): + """ + Create a mini version of the model. + The commented values are the original values + """ + model_config = MINI_MODEL_SETUPS[model_name].mini_model_config + model_class = MINI_MODEL_SETUPS[model_name].model_class + return model_class(model_config) + + +@require_deterministic +def run_mini_model( + model_name="mini_llama3", + num_steps=100, + dtype=torch.float32, + lr=1e-5, + with_liger=False, +): + # If we moved the seeding to the beginning of test_mini_model, the two runs would be initialized with different weights. + # This is due to the RNG (Random Number Generator). The formula of RNG progression is x_(n+1) = (a * x_n + c) % m. + # Every time the RNG is used, e.g. when randomly initializing a weight, it progresses to the next state. + # Therefore, we have to reset the RNG before we create each model to ensure that weight initialization starts from the same RNG state.
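 + # A minimal illustration of the point above (a hypothetical snippet, not part of this test): after torch.manual_seed(42), + # two consecutive torch.randn(1) calls return different tensors, since each call advances the global RNG state; + # re-seeding with torch.manual_seed(42) between the two calls makes the second call reproduce the first.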
+ + set_seed(42) + + revert_kwargs = {"model_config": MINI_MODEL_SETUPS[model_name]} + if "mllama" in model_name: + revert_kwargs["model_type"] = "causal_lm" + + if with_liger: + kwargs = { + "rope": True, + "rms_norm": True, + } + + if "glm4" in model_name or "qwen3_next" in model_name or "qwen3_5" in model_name: + kwargs["rope"] = False + + model_supports_layer_norm = "qwen2_vl" in model_name + if model_supports_layer_norm: + kwargs["layer_norm"] = True + + if "gemma" in model_name: + kwargs["geglu"] = True + else: + kwargs["swiglu"] = True + + if "llava" in model_name: + apply_liger_kernel_to_llama(**kwargs) + + # fused_linear_cross_entropy is not supported in mini_granite3 + kwargs["fused_linear_cross_entropy"] = model_name != "mini_granite3" + kwargs["cross_entropy"] = False + + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) + else: + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) + + model = create_model(model_name).to(dtype).to(device) + + train_dataset = load_from_disk(DEFAULT_DATASET_PATH) + loader = DataLoader(train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn) + loader_iter = iter(loader) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr) + + loss_list = [] + for i in range(num_steps): + batch = next(loader_iter).to(model.device) + optimizer.zero_grad() + output = model(**batch) + output.loss.backward() + optimizer.step() + print(f"Step {i}, Loss: {output.loss.item()}") + loss_list.append(output.loss.item()) + + model.eval() + eval_batch = next(loader_iter).to(model.device) + if with_liger: + eval_batch["skip_logits"] = False + with torch.no_grad(): + eval_output = model(**eval_batch) + print(f"Eval Loss: {eval_output.loss.item()}") + loss_list.append(eval_output.loss.item()) + topk_logprobs = get_topk(get_logprobs(eval_output.logits)) + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) + return { + "loss": loss_list, + "topk_logprobs": topk_logprobs.values, + "model": model, + } + + +@pytest.mark.parametrize( + "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol", + [ + pytest.param( + "mini_llama4", # llama4 requires slightly larger tolerances to pass this test after bug fix to llama4 in transformers v5.0.0 + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-3, + 5e-3, + 1e-3, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not LLAMA4_AVAILABLE, + reason="Llama4 not available in this version of transformers", + ), + # pytest.mark.xfail( + # reason=( + # "RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype:" + # " float key.dtype: c10::BFloat16 and value.dtype: c10::BFloat16 instead."
+ # ) + # ), + ], + ), + ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 5e-3, 1e-5, 5e-3, 1e-5), + pytest.param( + "mini_llava", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + pytest.mark.skipif( + version.parse(transformers.__version__) < version.parse("4.52.0"), + reason="LLaVa doesn't materialize logits in transformers<=4.52.0 so we can't test it", + ), + ], + ), + pytest.param( + "mini_mllama", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not MLLAMA_AVAILABLE, + reason="Mllama not available in this version of transformers", + ), + ), + pytest.param( + "mini_gemma3_text", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-4, + 5e-2, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not GEMMA3_AVAILABLE, + reason="Gemma3 not available in this version of transformers", + ), + ), + ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + pytest.param( + "mini_qwen3", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not QWEN3_AVAILABLE, + reason="Qwen3 not available in this version of transformers", + ), + ), + pytest.param( + "mini_qwen3_moe", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not QWEN3_AVAILABLE, + reason="Qwen3 not available in this version of transformers", + ), + ), + pytest.param( + "mini_gpt_oss", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not GPT_OSS_AVAILABLE, + reason="GPT-OSS not available in this version of transformers", + ), + ), + pytest.param( # qwen2_vl requires slightly larger tolerances to pass this test after bug fix to qwen2_vl in transformers v4.47.0 + "mini_qwen2_vl", + 32, + 1e-4, + torch.float32, + 1e-5, # 1e-8, + 1e-1, # 1e-5, + 5e-3, # 5e-3, + 1e-5, # 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", + ), + ), + # TODO: logits tolerances are significantly larger than the other tests, need to investigate + pytest.param( # qwen2_5_vl requires slightly larger tolerances to pass this test after bug fix to qwen2_vl in transformers v4.47.0 + "mini_qwen2_5_vl", + 32, + 1e-4, + torch.float32, + 1e-5, # 1e-8, + 1e-1, # 1e-5, + 5e-3, # 5e-3, + 1e-5, # 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not QWEN2_5_VL_AVAILABLE, + reason="Qwen2.5-VL not available in this version of transformers", + ), + ), + pytest.param( + "mini_qwen3_vl", + 32, + 1e-4, + torch.float32, + 1e-5, # 1e-8, + 1e-1, # 1e-5, + 5e-3, # 5e-3, + 1e-5, # 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not QWEN3_VL_AVAILABLE, + reason="Qwen3-VL not available in this version of transformers", + ), + ), + pytest.param( + "mini_qwen3_vl_moe", + 32, + 1e-4, + torch.float32, + 1e-5, + 1e-1, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not QWEN3_VL_MOE_AVAILABLE, + reason="Qwen3-VL-MoE not available in this version of transformers", + ), + ), + pytest.param( + "mini_olmo2", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not OLMO2_AVAILABLE, + reason="OLMO2 not available in this version of transformers", + ), + ), + pytest.param( + "mini_olmo3", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 
5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not OLMO3_AVAILABLE, + reason="OLMO3 not available in this version of transformers", + ), + ), + pytest.param( + "mini_glm4", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not GLM4_AVAILABLE, + reason="Glm4 not available in this version of transformers", + ), + ), + pytest.param( + "mini_glm4v", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not GLM4V_AVAILABLE, + reason="Glm4v not available in this version of transformers", + ), + ), + pytest.param( + "mini_glm4v_moe", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-3, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not GLM4V_MOE_AVAILABLE, + reason="Glm4v_moe not available in this version of transformers", + ), + ], + ), + ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + pytest.param( + "mini_mistral", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[], + ), + # TODO: mixtral is flaky so disable the test for now + # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5), + # Gemma 1.1 and 2 have larger tolerances because the kernel is currently not a perfect match (casts are not done the same way) + ("mini_gemma1", 32, 1e-5, torch.float32, 1e-8, 1e-4, 5e-2, 1e-5, 5e-3, 1e-5), + ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + pytest.param( + "mini_granite3", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-4, + 4e-2, # 4e-3 + 1e-5, # 1e-5 + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not GRANITE_AVAILABLE, + reason="Granite not available in this version of transformers", + ), + ), + pytest.param( + "mini_smollm3", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=pytest.mark.skipif( + not SMOLLM3_AVAILABLE, + reason="Smollm3 not available in this version of transformers", + ), + ), + pytest.param( + "mini_internvl", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not INTERNVL_AVAILABLE, + reason="InternVL not available in this version of transformers", + ), + ), + pytest.param( + "mini_falcon_h1", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-4, + 4e-2, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not FALCONH1_AVAILABLE, + reason="FalconH1 not available in this version of transformers", + ), + ), + pytest.param( + "mini_qwen3_next", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not QWEN3NEXT_AVAILABLE, + reason="Qwen3Next not available in this version of transformers", + ), + pytest.mark.skip( + reason="flash-linear-attention's ChunkGatedDeltaRuleFunction does not support float32.\n" + + " Torch's implementation takes too long" + ), + ], + ), + pytest.param( + "mini_qwen3_5_moe", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not QWEN3_5_MOE_AVAILABLE, + reason="Qwen3_5Moe not available in this version of transformers", + ), + pytest.mark.skip( + reason="flash-linear-attention's ChunkGatedDeltaRuleFunction does not support float32.\n" + + " Torch's implementation takes too long" + ), + ], + ), + pytest.param( + "mini_qwen3_5", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ +
pytest.mark.skipif( + not QWEN3_5_AVAILABLE, + reason="Qwen3_5 not available in this version of transformers", + ), + pytest.mark.skip( + reason="flash-linear-attention's ChunkGatedDeltaRuleFunction does not support float32.\n" + + " Torch's implementation takes too long" + ), + ], + ), + pytest.param( + "mini_hunyuan_v1", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not HUNYUAN_V1_AVAILABLE, + reason="Hunyuan_v1 not available in this version of transformers", + ), + ), + pytest.param( + "mini_hunyuan_v1_moe", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not HUNYUAN_V1_AVAILABLE, + reason="Hunyuan_v1_moe not available in this version of transformers", + ), + ), + pytest.param( + "mini_exaone4", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not EXAONE4_AVAILABLE, + reason="EXAONE4 not available in this version of transformers", + ), + ), + ], +) +def test_mini_model( + model_name, + num_steps, + lr, + dtype, + loss_atol, + loss_rtol, + logprobs_atol, + logprobs_rtol, + param_atol, + param_rtol, +): + # Non-liger models should be initialized and tested first to avoid the module being overridden + + expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr) + + actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True) + + # Compare every step of the loss + assert_verbose_allclose( + torch.tensor([expected_output["loss"]]), + torch.tensor([actual_output["loss"]]), + atol=loss_atol, + rtol=loss_rtol, + extra_info="[Loss]", + ) + + # Compare the topk logprobs from evaluation step + if expected_output["topk_logprobs"] is not None and actual_output["topk_logprobs"] is not None: + assert_verbose_allclose( + expected_output["topk_logprobs"], + actual_output["topk_logprobs"], + atol=logprobs_atol, + rtol=logprobs_rtol, + extra_info="[Top k logprobs]", + ) + + # Compare the params from the last step + # Iterate over the model's parameters and compare them + for expected_param, actual_param in zip( + expected_output["model"].named_parameters(), + actual_output["model"].named_parameters(), + ): + assert_verbose_allclose( + expected_param[1], + actual_param[1], + atol=param_atol, + rtol=param_rtol, + extra_info="[Model parameters]", + ) diff --git a/test/convergence/fp32/test_mini_models_multimodal.py b/test/convergence/fp32/test_mini_models_multimodal.py new file mode 100755 index 0000000000000000000000000000000000000000..f3e59bc3a1ae8ee2353f657f8bf1e67c967d5df1 --- /dev/null +++ b/test/convergence/fp32/test_mini_models_multimodal.py @@ -0,0 +1,1934 @@ +import functools +import os + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Ensure deterministic behavior with CuBLAS + +import pytest +import torch +import transformers + +from datasets import load_dataset +from packaging import version +from torch.utils.data import DataLoader +from transformers import PreTrainedTokenizerFast +from transformers.models.siglip.configuration_siglip import SiglipVisionConfig + +from liger_kernel.transformers import apply_liger_kernel_to_gemma3 +from liger_kernel.transformers import apply_liger_kernel_to_internvl +from liger_kernel.transformers import apply_liger_kernel_to_llama4 +from liger_kernel.transformers import apply_liger_kernel_to_llava +from liger_kernel.transformers import apply_liger_kernel_to_mllama +from 
liger_kernel.transformers import apply_liger_kernel_to_paligemma +from liger_kernel.transformers import apply_liger_kernel_to_pixtral +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl_moe +from liger_kernel.transformers import apply_liger_kernel_to_smolvlm +from liger_kernel.utils import infer_device +from test.utils import FAKE_CONFIGS_PATH +from test.utils import UNTOKENIZED_DATASET_PATH +from test.utils import MiniModelConfig +from test.utils import assert_verbose_allclose +from test.utils import get_logprobs +from test.utils import get_topk +from test.utils import is_torchvision_available +from test.utils import load_image_processing_config +from test.utils import load_processor_config +from test.utils import load_tokenizer_config +from test.utils import multimodal_collate_fn +from test.utils import require_deterministic +from test.utils import revert_liger_kernel_to_gemma3 +from test.utils import revert_liger_kernel_to_internvl +from test.utils import revert_liger_kernel_to_llama4 +from test.utils import revert_liger_kernel_to_llava +from test.utils import revert_liger_kernel_to_mllama +from test.utils import revert_liger_kernel_to_Paligemma +from test.utils import revert_liger_kernel_to_pixtral +from test.utils import revert_liger_kernel_to_qwen2_5_vl +from test.utils import revert_liger_kernel_to_qwen2_vl +from test.utils import revert_liger_kernel_to_qwen3_5 +from test.utils import revert_liger_kernel_to_qwen3_vl +from test.utils import revert_liger_kernel_to_qwen3_vl_moe +from test.utils import revert_liger_kernel_to_smolvlm2 +from test.utils import set_seed +from test.utils import train_bpe_tokenizer + +IS_TRANSFORMERS_V5_OR_LATER = version.parse(transformers.__version__) >= version.parse("5.0.0") + +if IS_TRANSFORMERS_V5_OR_LATER: + from transformers.models.gemma.tokenization_gemma import GemmaTokenizer +else: + from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast as GemmaTokenizer + +try: + # Qwen2-VL is only available in transformers>=4.52.4 + import transformers + + from packaging import version + + if IS_TRANSFORMERS_V5_OR_LATER: + from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer + else: + from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast as Qwen2Tokenizer + from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig + from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration + from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor + from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor + + QWEN2_VL_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.52.4") +except ImportError: + QWEN2_VL_AVAILABLE = False + +try: + # Qwen2.5-VL is only available in transformers>4.48.2 + import transformers + + from packaging import version + + if IS_TRANSFORMERS_V5_OR_LATER: + from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer + else: + from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast as Qwen2Tokenizer + from 
transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration + from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessor + from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor + from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor + + QWEN2_5_VL_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.52.4") +except ImportError: + QWEN2_5_VL_AVAILABLE = False + + +try: + if IS_TRANSFORMERS_V5_OR_LATER: + from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer + else: + from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast as Qwen2Tokenizer + from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor + from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig + from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLTextConfig + from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLVisionConfig + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration + from transformers.models.qwen3_vl.processing_qwen3_vl import Qwen3VLProcessor + from transformers.models.qwen3_vl.video_processing_qwen3_vl import Qwen3VLVideoProcessor + + QWEN3_VL_AVAILABLE = True +except ImportError: + QWEN3_VL_AVAILABLE = False + + +try: + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeTextConfig + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeVisionConfig + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration + + QWEN3_VL_MOE_AVAILABLE = True +except ImportError: + QWEN3_VL_MOE_AVAILABLE = False + +try: + from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor + from transformers.models.qwen3_5.configuration_qwen3_5 import 
Qwen3_5Config + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5VisionConfig + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForConditionalGeneration + from transformers.models.qwen3_vl.processing_qwen3_vl import Qwen3VLProcessor + from transformers.models.qwen3_vl.video_processing_qwen3_vl import Qwen3VLVideoProcessor + + QWEN3_5_AVAILABLE = True +except ImportError: + QWEN3_5_AVAILABLE = False + +try: + # Mllama is only available in transformers>=4.45.0 + from transformers.models.mllama.configuration_mllama import MllamaConfig + from transformers.models.mllama.configuration_mllama import MllamaTextConfig + from transformers.models.mllama.configuration_mllama import MllamaVisionConfig + from transformers.models.mllama.image_processing_mllama import MllamaImageProcessor + from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration + from transformers.models.mllama.processing_mllama import MllamaProcessor + + MLLAMA_AVAILABLE = True +except ImportError: + MLLAMA_AVAILABLE = False + +try: + from transformers import CLIPImageProcessor + from transformers import CLIPVisionConfig + from transformers import LlamaConfig + from transformers.models.llava.configuration_llava import LlavaConfig + from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration + from transformers.models.llava.processing_llava import LlavaProcessor + + from liger_kernel.transformers import apply_liger_kernel_to_llama + + LLAVA_AVAILABLE = True +except ImportError: + LLAVA_AVAILABLE = False + +try: + from transformers.models.llama4.configuration_llama4 import Llama4Config + from transformers.models.llama4.configuration_llama4 import Llama4TextConfig + from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig + from transformers.models.llama4.image_processing_llama4_fast import Llama4ImageProcessorFast + from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration + from transformers.models.llama4.processing_llama4 import Llama4Processor + + LLAMA4_AVAILABLE = True + +except ImportError: + LLAMA4_AVAILABLE = False + +try: + import transformers + + from packaging import version + from transformers.models.gemma.configuration_gemma import GemmaConfig + from transformers.models.gemma2.configuration_gemma2 import Gemma2Config + from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig + from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration + from transformers.models.paligemma.processing_paligemma import PaliGemmaProcessor + from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor + + PALIGEMMA_AVAILABLE = True +except ImportError: + PALIGEMMA_AVAILABLE = False + +try: + # Gemma3 is only available in transformers>=4.50.0 + from transformers.models.gemma3.configuration_gemma3 import Gemma3Config + from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig + from transformers.models.gemma3.image_processing_gemma3 import Gemma3ImageProcessor + from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration + from transformers.models.gemma3.processing_gemma3 import Gemma3Processor + + GEMMA3_AVAILABLE = True +except ImportError: + GEMMA3_AVAILABLE = False + +try: + # InternVL is only available in transformers>=4.52.1 + from transformers.models.got_ocr2.image_processing_got_ocr2_fast import 
GotOcr2ImageProcessorFast + from transformers.models.internvl.configuration_internvl import InternVLConfig + from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration + from transformers.models.internvl.processing_internvl import InternVLProcessor + from transformers.models.internvl.video_processing_internvl import InternVLVideoProcessor + from transformers.models.qwen2.configuration_qwen2 import Qwen2Config + + INTERNVL_AVAILABLE = True +except ImportError: + INTERNVL_AVAILABLE = False + +try: + # SmolVLM2 is only available in transformers>=4.50.0 + from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + from transformers.models.smolvlm.configuration_smolvlm import SmolVLMConfig + from transformers.models.smolvlm.image_processing_smolvlm import SmolVLMImageProcessor + from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration + from transformers.models.smolvlm.processing_smolvlm import SmolVLMProcessor + from transformers.models.smolvlm.video_processing_smolvlm import SmolVLMVideoProcessor + + SMOLVLM2_AVAILABLE = True +except ImportError: + SMOLVLM2_AVAILABLE = False + +try: + from transformers.models.pixtral.configuration_pixtral import PixtralVisionConfig + from transformers.models.pixtral.modeling_pixtral import PixtralVisionModel + + PIXTRAL_AVAILABLE = True +except ImportError: + PIXTRAL_AVAILABLE = False + +try: + from num2words import num2words # noqa: F401 + + NUM2WORDS_AVAILABLE = True +except ImportError: + NUM2WORDS_AVAILABLE = False + + +device = infer_device() + +torch.use_deterministic_algorithms(True) + +# Only setting torch.use_deterministic_algorithms(True) throws the following error: +# RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, +# but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an +# environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, +# go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + +TEST_IMAGE_DIM = 64 + +MINI_MODEL_SETUPS = {} + +if LLAMA4_AVAILABLE: + MINI_MODEL_SETUPS["mini_llama4"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_llama4, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_llama4, + model_class=Llama4ForConditionalGeneration, + mini_model_config=Llama4Config( + image_token_index=8, + vision_config=Llama4VisionConfig( + attn_implementation_autoset=True, + attention_dropout=0.0, + hidden_act="gelu", + hidden_size=512, # 1280 + image_size=560, # 560 + initializer_range=0.02, + intermediate_layers_indices=[2], # [3, 7, 15, etc...] + intermediate_size=2048, # 5120 + max_num_tiles=1, # 4 + norm_eps=1e-5, + num_attention_heads=4, # 16 + num_channels=3, + num_global_layers=2, # 8 + num_hidden_layers=8, # 32 + patch_size=280, # 14 + supported_aspect_ratios=[[1, 1]], # [[1, 1], [1, 2], etc... ] + vision_output_dim=4096, # 7680 + ), + text_config=Llama4TextConfig( + bos_token_id=0, + eos_token_id=0, + pad_token_id=0, + cross_attention_layers=[2], # [3, 8, 13, 18, etc...] 
+ dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=131_072, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + ), + attn_implementation="sdpa", + ), + ) + + +if MLLAMA_AVAILABLE: + MINI_MODEL_SETUPS["mini_mllama"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_mllama, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_mllama, + model_class=MllamaForConditionalGeneration, + mini_model_config=MllamaConfig( + vision_config=MllamaVisionConfig( + hidden_act="gelu", + hidden_size=512, # 1280 + image_size=560, # 560 + initializer_range=0.02, + intermediate_layers_indices=[2], # [3, 7, 15, etc...] + intermediate_size=2048, # 5120 + max_num_tiles=1, # 4 + norm_eps=1e-5, + num_attention_heads=4, # 16 + num_channels=3, + num_global_layers=2, # 8 + num_hidden_layers=8, # 32 + patch_size=140, # 14 + supported_aspect_ratios=[[1, 1]], # [[1, 1], [1, 2], etc... ] + vision_output_dim=1024, # 7680 + ), + text_config=MllamaTextConfig( + bos_token_id=0, + eos_token_id=0, + pad_token_id=0, + cross_attention_layers=[2], # [3, 8, 13, 18, etc...] + dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=131_072, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + rope_scaling=dict( + factor=8.0, + high_freq_factor=4.0, + low_freq_factor=1.0, + original_max_position_embeddings=8192, + rope_type="llama3", + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + ), + image_token_index=1, # NOTE: outside the vocab size + attn_implementation="sdpa", + ), + ) + +if PALIGEMMA_AVAILABLE: + MINI_MODEL_SETUPS["mini_paligemma"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_paligemma, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_Paligemma, + model_class=PaliGemmaForConditionalGeneration, + mini_model_config=PaliGemmaConfig( + vision_config=SiglipVisionConfig( + attention_dropout=0.0, + hidden_act="gelu_pytorch_tanh", + hidden_size=1152, + image_size=224, + intermediate_size=2048, # 4304 + layer_norm_eps=1e-06, + num_attention_heads=4, # 16 + num_channels=3, + num_hidden_layers=4, # 27 + num_image_tokens=256, + num_positions=256, + patch_size=14, + projection_dim=1024, # 2304 + ), + text_config=GemmaConfig( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + ), + image_token_index=4, # NOTE: outside the vocab size + attn_implementation="eager", + vocab_size=32000, + projection_dim=1024, + ), + ) + + 
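# PaliGemma 2 mini setup below: the SigLIP vision tower (SiglipVisionConfig) is identical to mini_paligemma above, but the text backbone switches from Gemma (GemmaConfig) to Gemma2 (Gemma2Config). + 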
MINI_MODEL_SETUPS["mini_paligemma2"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_paligemma, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_Paligemma, + model_class=PaliGemmaForConditionalGeneration, + mini_model_config=PaliGemmaConfig( + vision_config=SiglipVisionConfig( + attention_dropout=0.0, + hidden_act="gelu_pytorch_tanh", + hidden_size=1152, + image_size=224, + intermediate_size=2048, # 4304 + layer_norm_eps=1e-06, + num_attention_heads=4, # 16 + num_channels=3, + num_hidden_layers=4, # 27 + num_image_tokens=256, + num_positions=256, + patch_size=14, + projection_dim=1024, # 2304 + ), + text_config=Gemma2Config( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + ), + image_token_index=4, # NOTE: outside the vocab size + attn_implementation="eager", + vocab_size=32000, + projection_dim=1024, + ), + ) + + +if GEMMA3_AVAILABLE: + MINI_MODEL_SETUPS["mini_gemma3"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_gemma3, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma3, + model_class=Gemma3ForConditionalGeneration, + mini_model_config=Gemma3Config( + vision_config=SiglipVisionConfig( + attention_dropout=0.0, + hidden_act="gelu_pytorch_tanh", + hidden_size=1152, + image_size=224, + intermediate_size=2048, # 4304 + layer_norm_eps=1e-06, + num_attention_heads=4, # 16 + num_channels=3, + num_hidden_layers=4, # 27 + num_image_tokens=256, + num_positions=256, + patch_size=14, + ).to_dict(), + text_config=Gemma3TextConfig( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + ), + image_token_index=5, # NOTE: outside the vocab size + boi_token_index=4, + eoi_token_index=6, + ), + ) + +if QWEN2_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_qwen2_vl, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_vl, + model_class=Qwen2VLForConditionalGeneration, + mini_model_config=Qwen2VLConfig( + attention_dropout=0.0, + # Token Ids and vocab size must match those in the tokenizer/processor + # test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json + bos_token_id=0, + eos_token_id=0, + vision_start_token_id=1, + vision_end_token_id=2, + vision_token_id=3, + image_token_id=4, + video_token_id=5, + hidden_act="silu", + hidden_size=1024, # 8192 + initializer_range=0.02, + intermediate_size=1024, # 29568 + max_position_embeddings=32768, + 
max_window_layers=4, # 80 + num_attention_heads=8, # 64 + num_hidden_layers=4, # 80 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-6, # 1e-5 + **( + dict(rope_parameters=dict(mrope_section=[16, 24, 24])) # (temporal, height, width) + if IS_TRANSFORMERS_V5_OR_LATER + else dict(rope_scaling=dict(type="mrope", mrope_section=[16, 24, 24])) + ), + sliding_window=4096, + tie_word_embeddings=True, + use_cache=False, # True + vocab_size=32000, # 152064, + use_sliding_window=False, + vision_config={ + "depth": 4, # 32 + "embed_dim": 128, # 1280 + "mlp_ratio": 1, + "num_heads": 8, # 16 + "in_chans": 3, + "hidden_size": 1024, # 1536 + }, + attn_implementation="sdpa", + ), + ) + +if LLAVA_AVAILABLE: + # https://huggingface.co/llava-hf/llava-1.5-7b-hf + MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_llava, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_llava, + model_class=LlavaForConditionalGeneration, + mini_model_config=LlavaConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pretraining_tp=1, + tie_word_embeddings=False, + use_cache=True, + max_position_embeddings=4096, # llava-1.5-7b-hf + rms_norm_eps=1e-05, # llava-1.5-7b-hf + vocab_size=32064, # llava-1.5-7b-hf + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + vision_config=CLIPVisionConfig( + hidden_size=1024, + image_size=336, + intermediate_size=2048, # 4096 + model_type="clip_vision_model", + num_attention_heads=4, # 16 + num_hidden_layers=4, # 24 + patch_size=14, + projection_dim=768, + vocab_size=32000, + ), + vocab_size=32064, + ignore_index=-100, + pad_token_id=4, + image_token_index=3, + projector_hidden_act="gelu", + vision_feature_layer=-2, + vision_feature_select_strategy="default", + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if INTERNVL_AVAILABLE: + MINI_MODEL_SETUPS["mini_internvl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_internvl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_internvl, + model_class=InternVLForConditionalGeneration, + mini_model_config=InternVLConfig( + text_config=Qwen2Config( + rms_norm_eps=1e-5, + hidden_size=256, # 1024 + intermediate_size=1024, # 4096 + hidden_act="silu", + num_hidden_layers=4, # 24 + num_attention_heads=4, # 16 + num_key_value_heads=2, # 16 + max_position_embeddings=4096, # 8192 + vocab_size=32000, # 151936 + bos_token_id=1, + eos_token_id=2, + pad_token_id=2, + tie_word_embeddings=False, + ), + vision_config={ + "hidden_size": 256, # 1024 + "intermediate_size": 1024, # 4096 + "num_hidden_layers": 4, # 24 + "num_attention_heads": 4, # 16 + }, + image_token_id=24, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if SMOLVLM2_AVAILABLE: + MINI_MODEL_SETUPS["mini_smolvlm2"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_smolvlm, + 
liger_kernel_patch_revert_func=revert_liger_kernel_to_smolvlm2, + model_class=SmolVLMForConditionalGeneration, + mini_model_config=SmolVLMConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + pad_token_id=2, + hidden_act="silu", + hidden_size=576, # 576 for 256M model + initializer_range=0.041666666666666664, + intermediate_size=1536, # 1536 for 256M model + max_position_embeddings=8192, + num_attention_heads=9, # 9 for 256M model + num_hidden_layers=4, # 30 -> reduced to 4 for testing + num_key_value_heads=3, # 3 for 256M model + rms_norm_eps=1e-5, + tie_word_embeddings=False, + vocab_size=49280, + ), + vision_config={ + "hidden_size": 768, + "intermediate_size": 3072, + "num_hidden_layers": 4, # 12 -> reduced to 4 for testing + "num_attention_heads": 12, + "image_size": 512, + "patch_size": 16, + }, + image_token_id=49190, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if QWEN2_5_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_5_vl"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_qwen2_5_vl, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_5_vl, + model_class=Qwen2_5_VLForConditionalGeneration, + mini_model_config=Qwen2_5_VLConfig( + attention_dropout=0.0, + # Token Ids and vocab size must match those in the tokenizer/processor + # test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json + bos_token_id=0, + eos_token_id=0, + vision_start_token_id=1, + vision_end_token_id=2, + vision_token_id=3, + image_token_id=4, + video_token_id=5, + hidden_act="silu", + hidden_size=1024, # 8192 + initializer_range=0.02, + intermediate_size=1024, # 29568 + max_position_embeddings=32768, + max_window_layers=4, # 80 + num_attention_heads=8, # 64 + num_hidden_layers=4, # 80 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-6, # 1e-5 + **( + dict(rope_parameters=dict(mrope_section=[16, 24, 24])) # (temporal, height, width) + if IS_TRANSFORMERS_V5_OR_LATER + else dict(rope_scaling=dict(type="mrope", mrope_section=[16, 24, 24])) + ), + sliding_window=4096, + tie_word_embeddings=True, + use_cache=False, # True + vocab_size=32000, # 152064, + use_sliding_window=False, + vision_config={ + "depth": 4, # 32 + "hidden_size": 128, # 1280 + "num_heads": 16, + "in_chans": 3, + "out_hidden_size": 1024, + }, + attn_implementation="sdpa", + ), + ) + +if QWEN3_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_vl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_vl, + model_class=Qwen3VLForConditionalGeneration, + mini_model_config=Qwen3VLConfig( + attn_implementation="sdpa", + image_token_id=4, + video_token_id=5, + vision_start_token_id=1, + vision_end_token_id=2, + tie_word_embeddings=True, + vision_config=Qwen3VLVisionConfig( + depth=4, + hidden_size=256, + hidden_act="gelu_pytorch_tanh", + intermediate_size=512, + num_heads=4, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=512, + num_position_embeddings=256, + deepstack_visual_indexes=[1, 2, 3], + initializer_range=0.02, + ).to_dict(), + text_config=Qwen3VLTextConfig( + vocab_size=32000, + hidden_size=512, + intermediate_size=2048, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + head_dim=64, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + 
use_cache=False, + tie_word_embeddings=True, + attention_dropout=0.0, + attention_bias=False, + ).to_dict(), + ), + ) + +if QWEN3_VL_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_vl_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_vl_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_vl_moe, + model_class=Qwen3VLMoeForConditionalGeneration, + mini_model_config=Qwen3VLMoeConfig( + attn_implementation="sdpa", + image_token_id=4, + video_token_id=5, + vision_start_token_id=1, + vision_end_token_id=2, + tie_word_embeddings=True, + vision_config=Qwen3VLMoeVisionConfig( + depth=4, + hidden_size=256, + hidden_act="gelu_pytorch_tanh", + intermediate_size=512, + num_heads=4, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=512, + num_position_embeddings=256, + deepstack_visual_indexes=[1, 2, 3], + initializer_range=0.02, + ).to_dict(), + 
text_config=Qwen3VLMoeTextConfig( + vocab_size=32000, + hidden_size=512, + intermediate_size=2048, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + head_dim=64, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=False, + tie_word_embeddings=True, + attention_dropout=0.0, + attention_bias=False, + decoder_sparse_step=1, + moe_intermediate_size=1024, + num_experts_per_tok=2, + num_experts=4, + mlp_only_layers=[], + ).to_dict(), + ), + ) + +if QWEN3_5_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_5"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_5, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_5, + model_class=Qwen3_5ForConditionalGeneration, + mini_model_config=Qwen3_5Config( + attn_implementation="sdpa", + image_token_id=4, + video_token_id=5, + vision_start_token_id=1, + vision_end_token_id=2, + tie_word_embeddings=True, + vision_config=Qwen3_5VisionConfig( + depth=4, + hidden_size=256, + hidden_act="gelu_pytorch_tanh", + intermediate_size=512, + num_heads=4, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=512, + num_position_embeddings=256, + deepstack_visual_indexes=[1, 2, 3], + initializer_range=0.02, + ).to_dict(), + text_config=Qwen3_5TextConfig( + vocab_size=32000, + hidden_size=512, + intermediate_size=2048, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + head_dim=64, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=False, + tie_word_embeddings=True, + rope_scaling=dict( + type="mrope", + mrope_section=[16, 24, 24], # (temporal, height, width) + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + attention_dropout=0.0, + attention_bias=False, + decoder_sparse_step=1, + moe_intermediate_size=1024, + num_experts_per_tok=2, + num_experts=4, + mlp_only_layers=[], + pad_token_id=None, + ).to_dict(), + ), + ) + +if PIXTRAL_AVAILABLE: + MINI_MODEL_SETUPS["mini_pixtral"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_pixtral, + liger_kernel_patch_revert_func=revert_liger_kernel_to_pixtral, + model_class=PixtralVisionModel, + mini_model_config=PixtralVisionConfig( + hidden_size=1024, + intermediate_size=2048, + num_hidden_layers=4, + num_attention_heads=8, + num_channels=3, + image_size=256, + patch_size=16, + hidden_act="silu", + attention_dropout=0.0, + rope_theta=10000.0, + initializer_range=0.02, + ), + ) + + +def create_processor(model_name: str): + if model_name == "mini_qwen2_vl": + tokenizer_config = load_tokenizer_config( + os.path.join(FAKE_CONFIGS_PATH, "Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json") + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + qwen_tokenizer = Qwen2Tokenizer(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = Qwen2VLImageProcessor() + video_processor = Qwen2VLVideoProcessor() + return Qwen2VLProcessor( + image_processor=image_processor, + video_processor=video_processor, + tokenizer=qwen_tokenizer, + ) + + elif model_name == "mini_qwen2_5_vl": + tokenizer_config = load_tokenizer_config( + os.path.join(FAKE_CONFIGS_PATH, "Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json") + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), 
+ key=lambda x: int(x[0]), + ) + ] + ) + qwen_tokenizer = Qwen2Tokenizer(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = Qwen2VLImageProcessor() + video_processor = Qwen2VLVideoProcessor() + return Qwen2_5_VLProcessor( + image_processor=image_processor, + video_processor=video_processor, + tokenizer=qwen_tokenizer, + ) + + elif model_name in ("mini_qwen3_vl", "mini_qwen3_vl_moe", "mini_qwen3_5"): + tokenizer_config = load_tokenizer_config( + os.path.join(FAKE_CONFIGS_PATH, "Qwen/Qwen3-VL-4B-Instruct/tokenizer_config.json") + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + qwen_tokenizer = Qwen2Tokenizer(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = Qwen2VLImageProcessor(patch_size=16, temporal_patch_size=2, merge_size=2) + video_processor = Qwen3VLVideoProcessor() + return Qwen3VLProcessor( + image_processor=image_processor, + video_processor=video_processor, + tokenizer=qwen_tokenizer, + ) + + elif model_name == "mini_llava": + tokenizer_config = load_tokenizer_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Llava/llava-1.5-7b-hf/tokenizer_config.json", + ) + ) + image_processor_config = load_image_processing_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Llava/llava-1.5-7b-hf/preprocessor_config.json", + ) + ) + processor_config = load_processor_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Llava/llava-1.5-7b-hf/processor_config.json", + ) + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + + fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config) + fast_tokenizer.model_input_names = ["input_ids", "attention_mask"] + image_processor = CLIPImageProcessor(**image_processor_config) + + return LlavaProcessor(**processor_config, image_processor=image_processor, tokenizer=fast_tokenizer) + + elif model_name == "mini_internvl": + tokenizer_config = load_tokenizer_config( + os.path.join(FAKE_CONFIGS_PATH, "OpenGVLab/InternVL3-1B-hf/tokenizer_config.json") + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + qwen_tokenizer = Qwen2Tokenizer(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = GotOcr2ImageProcessorFast( + crop_to_patches=False, min_patches=1, max_patches=12, size={"height": 448, "width": 448} + ) + video_processor = InternVLVideoProcessor() + + # Return proper InternVL processor + return InternVLProcessor( + image_processor=image_processor, tokenizer=qwen_tokenizer, video_processor=video_processor + ) + + elif model_name == "mini_smolvlm2": + tokenizer_config = load_tokenizer_config( + os.path.join(FAKE_CONFIGS_PATH, "HuggingFaceTB/SmolVLM2-256M-Video-Instruct/tokenizer_config.json") + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + gpt2_tokenizer = GPT2Tokenizer(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = SmolVLMImageProcessor(size={"longest_edge": 512}) + video_processor = SmolVLMVideoProcessor() + + # Return proper SmolVLM processor + return SmolVLMProcessor( + image_processor=image_processor, 
tokenizer=gpt2_tokenizer, video_processor=video_processor + ) + + elif model_name.startswith("mini_llama4"): + tokenizer_config = load_tokenizer_config( + os.path.join( + FAKE_CONFIGS_PATH, + "meta-llama/Llama-4-Scout-17B-16E-Instruct/tokenizer_config.json", + ) + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = Llama4ImageProcessorFast(size={"height": 560, "width": 560}) + return Llama4Processor( + image_processor=image_processor, + tokenizer=fast_tokenizer, + fake_image_token="<|image|>", + image_token="<|image|>", + ) + + elif model_name == "mini_mllama": + tokenizer_config = load_tokenizer_config( + os.path.join( + FAKE_CONFIGS_PATH, + "meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json", + ) + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = MllamaImageProcessor(size={"height": 560, "width": 560}) + return MllamaProcessor(image_processor=image_processor, tokenizer=fast_tokenizer) + + elif model_name.startswith("mini_paligemma"): + tokenizer_config = load_tokenizer_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json", + ) + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + fast_tokenizer = GemmaTokenizer(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = SiglipImageProcessor(size={"height": 224, "width": 224}, image_seq_length=256) + return PaliGemmaProcessor(image_processor=image_processor, tokenizer=fast_tokenizer) + + elif model_name.startswith("mini_gemma3"): + tokenizer_config = load_tokenizer_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Google/Gemma3/gemma-3-4b-it/tokenizer_config.json", + ) + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + fast_tokenizer = GemmaTokenizer(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = Gemma3ImageProcessor() + return Gemma3Processor(image_processor=image_processor, tokenizer=fast_tokenizer) + + else: + raise ValueError(f"Processor not available for model {model_name}") + + +def create_multimodal_dataset(model_name: str): + processor = create_processor(model_name) + + def generate_procedural_image(example, index): + """Generate an image with a single row of white pixels at the index specified""" + image = torch.zeros(3, TEST_IMAGE_DIM, TEST_IMAGE_DIM) + image[:, index % TEST_IMAGE_DIM, :] = 255 + example["image"] = image + return example + + def apply_chat_template(example): + """ + Under the hood, this inserts the correct image placeholder token into the text. + This conversation format is more or less the one used by HF's mllms. The fact that it is + formatted as if for IFT is not in and of itself important here. 
+ """ + conversation = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "Describe this image."}, + ], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": example["text"]}], + }, + ] + example["text"] = processor.tokenizer.apply_chat_template(conversation, tokenize=False) + return example + + def preprocess_function(examples): + """Tokenize text, preprocess images, and generate other relevant inputs for the model.""" + if model_name == "mini_llama4": + # Process images and text separately to avoid complex token replacement, this helped setting lower tolerance than processing them together. + image_inputs = processor.image_processor(images=examples["image"], return_tensors="pt") + text_inputs = processor.tokenizer( + examples["text"], + padding="max_length", + truncation=True, + max_length=1024, + return_tensors="pt", + ) + return {**text_inputs, **image_inputs} + else: + # For other models, use the normal processor + return processor( + text=examples["text"], + images=examples["image"], + padding="max_length", + truncation=True, + max_length=1024, # longer than for text-only b/c images require quite a few tokens + return_tensors="pt", + ) + + train_dataset = ( + load_dataset("text", data_files={"train": UNTOKENIZED_DATASET_PATH}, split="train") + .to_iterable_dataset() # only map examples as-needed and on-demand + .map(generate_procedural_image, with_indices=True) + .map(apply_chat_template) + .map(preprocess_function, remove_columns=["text", "image"]) + ) + return train_dataset + + +def create_model(model_name): + """ + Create a mini version model + The commented values are the original values + """ + model_config = MINI_MODEL_SETUPS[model_name].mini_model_config + model_class = MINI_MODEL_SETUPS[model_name].model_class + return model_class(model_config) + + +@require_deterministic +def run_mini_model_multimodal( + model_name="mini_qwen2_vl", + num_steps=100, + dtype=torch.bfloat16, + lr=1e-5, + with_liger=False, +): + # If we move it to the beginning of test_mini_model, the two runs are initialized with different weights. + # This is due to RNG (Random Number Generator). The formula of RNG progression is x_(n+1) = (a * x_n + c) % m + # Everytime RNG is used, like randomly initialzing weight, the RNG progresses to the next state. + # Therefore, we have to reset RNG before we create the model to ensure the weight initialization started from the same RNG state. 
+ + set_seed(42) + + revert_kwargs = {"model_config": MINI_MODEL_SETUPS[model_name]} + if "mllama" in model_name or "llama4" in model_name or "qwen3_5" in model_name: + revert_kwargs["model_type"] = "conditional_generation" + + if with_liger is True: + kwargs = { + "rope": True, + "rms_norm": True, + "cross_entropy": False, + } + if "llama4" in model_name: + kwargs["rope"] = False + if ( + "qwen2_5_vl" not in model_name + and "llava" not in model_name + and "qwen3_vl" not in model_name + and "qwen3_5" not in model_name + ): + kwargs["layer_norm"] = True + + if "qwen3_5" in model_name: + kwargs["rope"] = False + + if "gemma" in model_name: + kwargs["geglu"] = True + else: + kwargs["swiglu"] = True + + if "llava" in model_name: + apply_liger_kernel_to_llama(**kwargs) + + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) + + else: + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) + + model = create_model(model_name).to(dtype).to(device) + + model.gradient_checkpointing_enable() + + train_dataset = create_multimodal_dataset(model_name) + loader = DataLoader(train_dataset, batch_size=2, shuffle=False, collate_fn=multimodal_collate_fn) + loader_iter = iter(loader) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr) + + loss_list = [] + + for i in range(num_steps): + batch = next(loader_iter).to(model.device) + optimizer.zero_grad() + output = model(**batch) + output.loss.backward() + optimizer.step() + + print(f"Step {i}, Loss: {output.loss.item()}") + loss_list.append(output.loss.item()) + + model.eval() + eval_batch = next(loader_iter).to(model.device) + if with_liger: + eval_batch["skip_logits"] = False + with torch.no_grad(): + eval_output = model(**eval_batch) + print(f"Eval Loss: {eval_output.loss.item()}") + loss_list.append(eval_output.loss.item()) + topk_logprobs = get_topk(get_logprobs(eval_output.logits)) + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) + return { + "loss": loss_list, + "topk_logprobs": topk_logprobs.values, + "model": model, + } + + +@pytest.mark.parametrize( + "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol", + [ + pytest.param( + "mini_qwen2_vl", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", + ), + pytest.mark.skipif(not is_torchvision_available(), reason="Qwen2VLVideoProcessor requires torchvision"), + ], + ), + # Disabled since the Llama4 image processor rescales and normalizes images to torch.bfloat16, so the dtype of the model parameters has to be bfloat16 + # Refer to: https://github.com/huggingface/transformers/blob/67ddc82fbc7e52c6f42a395b4a6d278c55b77a39/src/transformers/models/llama4/image_processing_llama4_fast.py#L371 + pytest.param( + "mini_llama4", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif( + not LLAMA4_AVAILABLE, + reason="Llama4 not available in this version of transformers", + ), + pytest.mark.xfail( + reason=( + "RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype:" + " float key.dtype: c10::BFloat16 and value.dtype: c10::BFloat16 instead." 
+ ) + ), + ], + ), + pytest.param( + "mini_llava", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + pytest.mark.skipif( + True, + reason="Flaky test", + ), + ], + ), + pytest.param( + "mini_internvl", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not INTERNVL_AVAILABLE, + reason="InternVL not available in this version of transformers", + ), + ), + pytest.param( + "mini_smolvlm2", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not SMOLVLM2_AVAILABLE, + reason="SmolVLM2 not available in this version of transformers", + ), + pytest.mark.skipif( + not NUM2WORDS_AVAILABLE, + reason="num2words must be present to run SmolVLMProcessor", + ), + ], + ), + pytest.param( + "mini_qwen2_5_vl", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not QWEN2_5_VL_AVAILABLE, + reason="Qwen2.5-VL not available in this version of transformers", + ), + pytest.mark.skipif(not is_torchvision_available(), reason="Qwen2VLVideoProcessor requires torchvision"), + ], + ), + pytest.param( + "mini_qwen3_vl", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not QWEN3_VL_AVAILABLE, + reason="Qwen3-VL not available in this version of transformers", + ), + pytest.mark.skipif( + not is_torchvision_available(), + reason="Qwen3VLVideoProcessor requires torchvision", + ), + pytest.mark.skipif( + True, + reason="Flaky test", + ), + ], + ), + pytest.param( + "mini_qwen3_vl_moe", + 32, + 1e-4, + torch.float32, + 1e-7, + 5e-4, + 5e-2, + 5e-3, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not QWEN3_VL_MOE_AVAILABLE, + reason="Qwen3-VL-MoE not available in this version of transformers", + ), + pytest.mark.skipif( + not is_torchvision_available(), + reason="Qwen3VLVideoProcessor requires torchvision", + ), + pytest.mark.skipif( + True, + reason="Flaky test", + ), + ], + ), + pytest.param( + "mini_mllama", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not MLLAMA_AVAILABLE, + reason="Mllama not available in this version of transformers", + ), + pytest.mark.skipif( + version.parse("4.51.0") > version.parse(transformers.__version__), + reason="MllamaForConditionalGeneration doesn't accept the `skip_logits` kwarg", + ), + ], + ), + pytest.param( + "mini_paligemma", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not PALIGEMMA_AVAILABLE, + reason="Paligemma not available in this version of transformers", + ), + ), + pytest.param( + "mini_paligemma2", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not PALIGEMMA_AVAILABLE, + reason="Paligemma2 not available in this version of transformers", + ), + ), + pytest.param( + "mini_gemma3", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-4, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not GEMMA3_AVAILABLE, + reason="Gemma3 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_qwen3_5", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not QWEN3_5_AVAILABLE, + reason="Qwen3.5 not available in this 
version of transformers", + ), + pytest.mark.skipif( + not is_torchvision_available(), + reason="Qwen3VLVideoProcessor requires torchvision", + ), + ], + ), + ], +) +def test_mini_model_multimodal( + model_name, + num_steps, + lr, + dtype, + loss_atol, + loss_rtol, + logprobs_atol, + logprobs_rtol, + param_atol, + param_rtol, +): + # Non-liger models should be initialized and tested first to avoid the module being overridden + expected_output = run_mini_model_multimodal(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr) + + actual_output = run_mini_model_multimodal( + model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True + ) + + # Compare the loss of every step + assert_verbose_allclose( + torch.tensor([expected_output["loss"]]), + torch.tensor([actual_output["loss"]]), + atol=loss_atol, + rtol=loss_rtol, + extra_info="[Loss]", + ) + + # Compare the logits from the last step + assert_verbose_allclose( + expected_output["topk_logprobs"], + actual_output["topk_logprobs"], + atol=logprobs_atol, + rtol=logprobs_rtol, + extra_info="[Top k logrpobs]", + ) + + # Compare the params from the last step + # Iterate over the model's parameters and compare them + for expected_param, actual_param in zip( + expected_output["model"].named_parameters(), + actual_output["model"].named_parameters(), + ): + assert_verbose_allclose( + expected_param[1], + actual_param[1], + atol=param_atol, + rtol=param_rtol, + extra_info="[Model parameters]", + ) + + +# +# Vision-only model tests (e.g. Pixtral vision encoder) +# + + +def generate_procedural_pixel_values(batch_size, num_channels, image_size, index, dtype, device): + """Generate deterministic pixel values for vision-only model testing. + + Each image has a single row of white pixels at a deterministic position, + providing a reproducible signal for convergence testing. 
+ """ + pixel_values = torch.zeros(batch_size, num_channels, image_size, image_size, dtype=dtype, device=device) + for b in range(batch_size): + row = (index + b) % image_size + pixel_values[b, :, row, :] = 1.0 + return pixel_values + + +@require_deterministic +def run_mini_model_vision( + model_name="mini_pixtral", + num_steps=100, + dtype=torch.float32, + lr=1e-5, + with_liger=False, +): + set_seed(42) + + revert_kwargs = {"model_config": MINI_MODEL_SETUPS[model_name]} + + if with_liger is True: + kwargs = { + "rope": True, + "rms_norm": True, + "swiglu": True, + } + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) + else: + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) + + model = create_model(model_name).to(dtype).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr) + + loss_list = [] + + for i in range(num_steps): + optimizer.zero_grad() + pixel_values = generate_procedural_pixel_values( + batch_size=2, + num_channels=model.config.num_channels, + image_size=model.config.image_size, + index=i, + dtype=dtype, + device=device, + ) + output = model(pixel_values=pixel_values) + loss = output.last_hidden_state.sum() + loss.backward() + optimizer.step() + print(f"Step {i}, Loss: {loss.item()}") + loss_list.append(loss.item()) + + # Eval step with deterministic input + model.eval() + with torch.no_grad(): + eval_pixel_values = generate_procedural_pixel_values( + batch_size=2, + num_channels=model.config.num_channels, + image_size=model.config.image_size, + index=num_steps, + dtype=dtype, + device=device, + ) + eval_output = model(pixel_values=eval_pixel_values) + + topk_logprobs = get_topk(get_logprobs(eval_output.last_hidden_state)) + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) + return { + "loss": loss_list, + "topk_logprobs": topk_logprobs.values, + "model": model, + } + + +@pytest.mark.parametrize( + "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol", + [ + pytest.param( + "mini_pixtral", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not PIXTRAL_AVAILABLE, reason="Pixtral not available in this version of transformers" + ), + ], + ), + ], +) +def test_mini_model_vision( + model_name, + num_steps, + lr, + dtype, + loss_atol, + loss_rtol, + logprobs_atol, + logprobs_rtol, + param_atol, + param_rtol, +): + # Non-liger models should be initialized and tested first to avoid the module being overridden + expected_output = run_mini_model_vision(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr) + + actual_output = run_mini_model_vision( + model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True + ) + + # Compare the loss of every step + assert_verbose_allclose( + torch.tensor([expected_output["loss"]]), + torch.tensor([actual_output["loss"]]), + atol=loss_atol, + rtol=loss_rtol, + extra_info="[Loss]", + ) + + # Compare the topk logprobs from evaluation step + assert_verbose_allclose( + expected_output["topk_logprobs"], + actual_output["topk_logprobs"], + atol=logprobs_atol, + rtol=logprobs_rtol, + extra_info="[Top k logprobs]", + ) + + # Compare the params from the last step + for expected_param, actual_param in zip( + expected_output["model"].named_parameters(), + actual_output["model"].named_parameters(), + ): + assert_verbose_allclose( + expected_param[1], + actual_param[1], + atol=param_atol, + rtol=param_rtol, + extra_info="[Model 
parameters]", + ) diff --git a/test/convergence/fp32/test_mini_models_with_logits.py b/test/convergence/fp32/test_mini_models_with_logits.py new file mode 100755 index 0000000000000000000000000000000000000000..d225e08bafa4a1f99ffb6f6840cfe55acf77b163 --- /dev/null +++ b/test/convergence/fp32/test_mini_models_with_logits.py @@ -0,0 +1,2038 @@ +import os + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Ensure deterministic behavior with CuBLAS + +import pytest +import torch +import transformers + +from datasets import load_from_disk +from packaging import version +from torch.utils.data import DataLoader +from transformers.models.gemma import GemmaConfig +from transformers.models.gemma import GemmaForCausalLM +from transformers.models.gemma2 import Gemma2Config +from transformers.models.gemma2 import Gemma2ForCausalLM +from transformers.models.llama import LlamaConfig +from transformers.models.llama import LlamaForCausalLM +from transformers.models.mistral import MistralConfig +from transformers.models.mistral import MistralForCausalLM +from transformers.models.mixtral import MixtralConfig +from transformers.models.mixtral import MixtralForCausalLM +from transformers.models.phi3 import Phi3Config +from transformers.models.phi3 import Phi3ForCausalLM +from transformers.models.qwen2 import Qwen2Config +from transformers.models.qwen2 import Qwen2ForCausalLM + +from liger_kernel.transformers import apply_liger_kernel_to_exaone4 +from liger_kernel.transformers import apply_liger_kernel_to_falcon_h1 +from liger_kernel.transformers import apply_liger_kernel_to_gemma +from liger_kernel.transformers import apply_liger_kernel_to_gemma2 +from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text +from liger_kernel.transformers import apply_liger_kernel_to_glm4 +from liger_kernel.transformers import apply_liger_kernel_to_glm4v +from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe +from liger_kernel.transformers import apply_liger_kernel_to_granite +from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_dense +from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_moe +from liger_kernel.transformers import apply_liger_kernel_to_internvl +from liger_kernel.transformers import apply_liger_kernel_to_llama +from liger_kernel.transformers import apply_liger_kernel_to_llama4 +from liger_kernel.transformers import apply_liger_kernel_to_llava +from liger_kernel.transformers import apply_liger_kernel_to_mistral +from liger_kernel.transformers import apply_liger_kernel_to_mixtral +from liger_kernel.transformers import apply_liger_kernel_to_mllama +from liger_kernel.transformers import apply_liger_kernel_to_olmo2 +from liger_kernel.transformers import apply_liger_kernel_to_olmo3 +from liger_kernel.transformers import apply_liger_kernel_to_phi3 +from liger_kernel.transformers import apply_liger_kernel_to_qwen2 +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen3 +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl +from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl_moe +from liger_kernel.transformers import apply_liger_kernel_to_smollm3 +from 
liger_kernel.utils import infer_device +from test.utils import DEFAULT_DATASET_PATH +from test.utils import MiniModelConfig +from test.utils import assert_verbose_allclose +from test.utils import get_logprobs +from test.utils import get_topk +from test.utils import require_deterministic +from test.utils import revert_liger_kernel_to_exaone4 +from test.utils import revert_liger_kernel_to_falcon_h1 +from test.utils import revert_liger_kernel_to_gemma +from test.utils import revert_liger_kernel_to_gemma2 +from test.utils import revert_liger_kernel_to_gemma3_text +from test.utils import revert_liger_kernel_to_glm4 +from test.utils import revert_liger_kernel_to_glm4v +from test.utils import revert_liger_kernel_to_glm4v_moe +from test.utils import revert_liger_kernel_to_granite +from test.utils import revert_liger_kernel_to_hunyuan_v1 +from test.utils import revert_liger_kernel_to_hunyuan_v1_moe +from test.utils import revert_liger_kernel_to_internvl +from test.utils import revert_liger_kernel_to_llama +from test.utils import revert_liger_kernel_to_llama4 +from test.utils import revert_liger_kernel_to_llava +from test.utils import revert_liger_kernel_to_mistral +from test.utils import revert_liger_kernel_to_mixtral +from test.utils import revert_liger_kernel_to_mllama +from test.utils import revert_liger_kernel_to_olmo2 +from test.utils import revert_liger_kernel_to_olmo3 +from test.utils import revert_liger_kernel_to_phi3 +from test.utils import revert_liger_kernel_to_qwen2 +from test.utils import revert_liger_kernel_to_qwen2_5_vl +from test.utils import revert_liger_kernel_to_qwen2_vl +from test.utils import revert_liger_kernel_to_qwen3 +from test.utils import revert_liger_kernel_to_qwen3_5 +from test.utils import revert_liger_kernel_to_qwen3_moe +from test.utils import revert_liger_kernel_to_qwen3_next +from test.utils import revert_liger_kernel_to_qwen3_vl +from test.utils import revert_liger_kernel_to_qwen3_vl_moe +from test.utils import revert_liger_kernel_to_smollm3 +from test.utils import set_seed +from test.utils import simple_collate_fn + +IS_TRANSFORMERS_V5_OR_LATER = version.parse(transformers.__version__) >= version.parse("5.0.0") + +try: + from transformers.models.llama4.configuration_llama4 import Llama4TextConfig + from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM + + LLAMA4_AVAILABLE = True +except ImportError: + LLAMA4_AVAILABLE = False + +try: + # Mllama is only available in transformers>=4.45.0 + from transformers.models.mllama.configuration_mllama import MllamaTextConfig + from transformers.models.mllama.modeling_mllama import MllamaForCausalLM + + MLLAMA_AVAILABLE = True +except ImportError: + MLLAMA_AVAILABLE = False + +try: + # Qwen2-VL is only available in transformers>4.52.4 + import transformers + + from packaging import version + from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration + + QWEN2_VL_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.52.4") +except ImportError: + QWEN2_VL_AVAILABLE = False + +try: + # Qwen2.5-VL is only available in transformers>4.52.4 + import transformers + + from packaging import version + from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration + + QWEN2_5_VL_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.52.4") +except 
ImportError: + QWEN2_5_VL_AVAILABLE = False + + +try: + # Qwen3-VL is only available in transformers>=4.57.0 + import transformers + + from packaging import version + from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration + + QWEN3_VL_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.57.0") +except ImportError: + QWEN3_VL_AVAILABLE = False + + +try: + # Qwen3-VL-MoE is only available in transformers>=4.57.0 + import transformers + + from packaging import version + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeTextConfig + from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeVisionConfig + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration + + QWEN3_VL_MOE_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.57.0") +except ImportError: + QWEN3_VL_MOE_AVAILABLE = False + +try: + from transformers.models.qwen3.configuration_qwen3 import Qwen3Config + from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM + from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig + from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeForCausalLM + + QWEN3_AVAILABLE = True +except ImportError: + QWEN3_AVAILABLE = False + +try: + from transformers.models.granite import GraniteConfig + from transformers.models.granite import GraniteForCausalLM + + GRANITE_AVAILABLE = True +except ImportError: + GRANITE_AVAILABLE = False + +try: + from transformers import CLIPVisionConfig + from transformers.models.llava.configuration_llava import LlavaConfig + from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration + + LLAVA_AVAILABLE = True +except ImportError: + LLAVA_AVAILABLE = False + +try: + # OLMO2 is only available in transformers>=4.47.0 + from transformers.models.olmo2.configuration_olmo2 import Olmo2Config + from transformers.models.olmo2.modeling_olmo2 import Olmo2ForCausalLM + + OLMO2_AVAILABLE = True +except ImportError: + OLMO2_AVAILABLE = False + +try: + # OLMO3 is only available in transformers>=4.57.0 + from transformers.models.olmo3.configuration_olmo3 import Olmo3Config + from transformers.models.olmo3.modeling_olmo3 import Olmo3ForCausalLM + + OLMO3_AVAILABLE = True +except ImportError: + OLMO3_AVAILABLE = False + +try: + # Glm4 is only available in transformers>=4.51.3 + from transformers.models.glm4.configuration_glm4 import Glm4Config + from transformers.models.glm4.modeling_glm4 import Glm4ForCausalLM + + GLM4_AVAILABLE = True +except ImportError: + GLM4_AVAILABLE = False + +try: + # Glm4v is only available in transformers>=4.51.3 + from transformers.models.glm4v.configuration_glm4v import Glm4vConfig + from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration + + GLM4V_AVAILABLE = True +except ImportError: + 
GLM4V_AVAILABLE = False + +try: + # Glm4v_moe is only available in transformers>=4.51.3 + from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig + from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration + + GLM4V_MOE_AVAILABLE = True +except ImportError: + GLM4V_MOE_AVAILABLE = False + +try: + from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig + from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM + + GEMMA3_AVAILABLE = True +except ImportError: + GEMMA3_AVAILABLE = False + +try: + # Smollm3 is only available in transformers>=4.53.0 + from transformers.models.smollm3.configuration_smollm3 import SmolLM3Config + from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM + + SMOLLM3_AVAILABLE = True +except ImportError: + SMOLLM3_AVAILABLE = False + +try: + # InternVL is only available in transformers>=4.52.1 + from transformers.models.internvl.configuration_internvl import InternVLConfig + from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration + + INTERNVL_AVAILABLE = True +except ImportError: + INTERNVL_AVAILABLE = False + +try: + # FalconH1 is only available in transformers>=4.53.0 + from transformers.models.falcon_h1.configuration_falcon_h1 import FalconH1Config + from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1ForCausalLM + + FALCONH1_AVAILABLE = True +except ImportError: + FALCONH1_AVAILABLE = False + +try: + # Qwen3Next is only available in transformers>=4.57.0 + from transformers.models.qwen3_next.configuration_qwen3_next import Qwen3NextConfig + from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM + + QWEN3NEXT_AVAILABLE = True +except ImportError: + QWEN3NEXT_AVAILABLE = False + +try: + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM + + QWEN3_5_AVAILABLE = True +except ImportError: + QWEN3_5_AVAILABLE = False + +try: + from transformers.models.hunyuan_v1_dense.configuration_hunyuan_v1_dense import HunYuanDenseV1Config + from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1ForCausalLM + from transformers.models.hunyuan_v1_moe.configuration_hunyuan_v1_moe import HunYuanMoEV1Config + from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1ForCausalLM + + HUNYUAN_V1_AVAILABLE = True +except ImportError: + HUNYUAN_V1_AVAILABLE = False + +try: + from transformers.models.exaone4.configuration_exaone4 import Exaone4Config + from transformers.models.exaone4.modeling_exaone4 import Exaone4ForCausalLM + + EXAONE4_AVAILABLE = True +except ImportError: + EXAONE4_AVAILABLE = False + + +device = infer_device() + +MINI_MODEL_SETUPS = { + "mini_llama3": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llama, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llama, + model_class=LlamaForCausalLM, + mini_model_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=8192, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 
+ num_key_value_heads=2, # 8 + pretraining_tp=1, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ), + "mini_qwen2": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2, + model_class=Qwen2ForCausalLM, + mini_model_config=Qwen2Config( + attention_dropout=0.0, + bos_token_id=1, # 151643 + eos_token_id=2, # 151643 + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, # 131072 + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-6, + sliding_window=131072, + tie_word_embeddings=True, + use_cache=True, + vocab_size=32000, # 151936 + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ), + "mini_phi3": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_phi3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_phi3, + model_class=Phi3ForCausalLM, + mini_model_config=Phi3Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, # 32000 + hidden_act="silu", + hidden_size=896, # 3072 + initializer_range=0.02, + intermediate_size=4864, # 8192 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=None, # defaults to num_attention_heads + rms_norm_eps=1e-5, + sliding_window=None, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32064, + attn_implementation="eager", + ), + ), + "mini_mistral": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mistral, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mistral, + model_class=MistralForCausalLM, + mini_model_config=MistralConfig( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-5, + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ), + "mini_mixtral": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mixtral, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mixtral, + model_class=MixtralForCausalLM, + mini_model_config=MixtralConfig( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=512, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=32768, # 32768 + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ), + "mini_gemma1": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma, + model_class=GemmaForCausalLM, + mini_model_config=GemmaConfig( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 
24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + # The gemma1 model config uses `hidden_act` and points it to gelu, + # https://huggingface.co/google/gemma-7b/blob/main/config.json#L10 + # but in reality it is ignored and HuggingFace uses the tanh approximation: + # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175 + hidden_act="gelu", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + ), + ), + "mini_gemma1.1": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma, + model_class=GemmaForCausalLM, + mini_model_config=GemmaConfig( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + ), + ), + "mini_gemma2": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma2, + model_class=Gemma2ForCausalLM, + mini_model_config=Gemma2Config( + vocab_size=32000, # 256000 + hidden_size=1024, # 3072 + intermediate_size=2048, # 24576 + num_hidden_layers=4, # 28 + num_attention_heads=4, # 16 + num_key_value_heads=4, # 16 + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + attn_implementation="eager", + ), + ), +} +if LLAMA4_AVAILABLE: + MINI_MODEL_SETUPS["mini_llama4"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llama4, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llama4, + model_class=Llama4ForCausalLM, + mini_model_config=Llama4TextConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + partial_rotary_factor=1.0, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + 
), + ) + +if QWEN3_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3, + model_class=Qwen3ForCausalLM, + mini_model_config=Qwen3Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-6, + sliding_window=131072, + tie_word_embeddings=True, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ) + + MINI_MODEL_SETUPS["mini_qwen3_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_moe, + model_class=Qwen3MoeForCausalLM, + mini_model_config=Qwen3MoeConfig( + vocab_size=32000, # 151936 + hidden_size=896, + intermediate_size=4864, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + attention_bias=False, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + decoder_sparse_step=1, + moe_intermediate_size=768, + num_experts_per_tok=2, + num_experts=8, + norm_topk_prob=False, + output_router_logits=False, + router_aux_loss_coef=0.001, + mlp_only_layers=None, + ), + ) + +if GEMMA3_AVAILABLE: + MINI_MODEL_SETUPS["mini_gemma3_text"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma3_text, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma3_text, + model_class=Gemma3ForCausalLM, + mini_model_config=Gemma3TextConfig( + vocab_size=32000, # 262144 + hidden_size=1024, # 1152 + intermediate_size=2048, # 6912 + num_hidden_layers=4, # 26 + num_attention_heads=4, + num_key_value_heads=1, + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, # 32768 + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=0, + bos_token_id=2, + eos_token_id=1, + tie_word_embeddings=True, + attention_bias=False, + attention_dropout=0.0, + attn_implementation="eager", + ), + ) + +if MLLAMA_AVAILABLE: + MINI_MODEL_SETUPS["mini_mllama"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mllama, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mllama, + model_class=MllamaForCausalLM, + mini_model_config=MllamaTextConfig( + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=131_072, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + attn_implementation="sdpa", # default value, pytorch native attention + rope_scaling=dict( + factor=8.0, + high_freq_factor=4.0, + low_freq_factor=1.0, + original_max_position_embeddings=8192, + rope_type="llama3", + rope_theta=500_000, + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + ), + ) + +if QWEN2_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2_vl, + 
liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_vl, + model_class=Qwen2VLForConditionalGeneration, + mini_model_config=Qwen2VLConfig( + attention_dropout=0.0, + # bos and eos set to match the Mistral-7B tokenizer used to create the test dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 151643 + eos_token_id=2, # 151645 + vision_start_token_id=32765, # vocab_size - 5 + vision_end_token_id=32766, # vocab_size - 4 + vision_token_id=32767, # vocab_size - 3 + image_token_id=32768, # vocab_size - 2 + video_token_id=32769, # vocab_size - 1 + hidden_act="silu", + hidden_size=1536, # 8192 + initializer_range=0.02, + intermediate_size=4864, # 29568 + max_position_embeddings=32768, + max_window_layers=4, # 80 + num_attention_heads=12, # 64 + num_hidden_layers=4, # 80 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-6, # 1e-5 + **( + dict(rope_parameters=dict(mrope_section=[16, 24, 24])) # (temporal, height, width) + if IS_TRANSFORMERS_V5_OR_LATER + else dict(rope_scaling=dict(type="mrope", mrope_section=[16, 24, 24])) + ), + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32768, # 152064 # >32k, Mistral-7B tokenizer vocab size + use_sliding_window=False, + vision_config={ + "depth": 4, # 32 + "embed_dim": 1280, + "mlp_ratio": 4, + "num_heads": 16, + "in_chans": 3, + "hidden_size": 128, # 1536 + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "temporal_patch_size": 2, + }, + attn_implementation="sdpa", + ), + ) + +if QWEN2_5_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_5_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2_5_vl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_5_vl, + model_class=Qwen2_5_VLForConditionalGeneration, + mini_model_config=Qwen2_5_VLConfig( + attention_dropout=0.0, + # bos and eos set to match the Mistral-7B tokenizer used to create the test dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 151643 + eos_token_id=2, # 151645 + vision_start_token_id=32765, # vocab_size - 5 + vision_end_token_id=32766, # vocab_size - 4 + vision_token_id=32767, # vocab_size - 3 + image_token_id=32768, # vocab_size - 2 + video_token_id=32769, # vocab_size - 1 + hidden_act="silu", + hidden_size=1536, # 8192 + initializer_range=0.02, + intermediate_size=4864, # 29568 + max_position_embeddings=32768, + max_window_layers=4, # 80 + num_attention_heads=12, # 64 + num_hidden_layers=4, # 80 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-6, # 1e-5 + **( + dict(rope_parameters=dict(mrope_section=[16, 24, 24])) # (temporal, height, width) + if IS_TRANSFORMERS_V5_OR_LATER + else dict(rope_scaling=dict(type="mrope", mrope_section=[16, 24, 24])) + ), + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32768, # 152064 # >32k, Mistral-7B tokenizer vocab size + use_sliding_window=False, + vision_config={ + "depth": 4, # 32 + "hidden_act": "silu", + "hidden_size": 128, # 1280 + "intermediate_size": 256, # 3420 + "num_heads": 16, + "in_chans": 3, + "out_hidden_size": 128, # 3584 + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "window_size": 112, + "fullatt_block_indexes": [7, 15, 23, 31], + "tokens_per_second": 2, + "temporal_patch_size": 2, + }, + attn_implementation="sdpa", + ), + ) + +if QWEN3_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_vl, + 
liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_vl, + model_class=Qwen3VLForConditionalGeneration, + mini_model_config=Qwen3VLConfig( + tie_word_embeddings=False, + image_token_id=31997, + video_token_id=31998, + vision_start_token_id=31995, + vision_end_token_id=31996, + text_config=dict( + attention_dropout=0.0, + attn_implementation="sdpa", + bos_token_id=1, + eos_token_id=2, + head_dim=112, + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pad_token_id=2, + rms_norm_eps=1e-6, + sliding_window=131072, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + rope_scaling=dict( + type="mrope", + mrope_section=[16, 24, 24], # (temporal, height, width) + ) + if not IS_TRANSFORMERS_V5_OR_LATER + else None, + ), + vision_config=dict( + depth=4, + hidden_size=128, + initializer_range=0.02, + intermediate_size=256, + num_heads=8, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=896, + num_position_embeddings=576, + deepstack_visual_indexes=[1, 2, 3], + ), + ), + ) + +if QWEN3_VL_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_vl_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_vl_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_vl_moe, + model_class=Qwen3VLMoeForConditionalGeneration, + mini_model_config=Qwen3VLMoeConfig( + tie_word_embeddings=False, + image_token_id=31997, + video_token_id=31998, + vision_start_token_id=31995, + vision_end_token_id=31996, + text_config=Qwen3VLMoeTextConfig( + attention_dropout=0.0, + attention_bias=False, + attn_implementation="sdpa", + bos_token_id=1, + eos_token_id=2, + head_dim=112, + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pad_token_id=2, + rms_norm_eps=1e-6, + sliding_window=131072, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + decoder_sparse_step=1, + moe_intermediate_size=3072, + num_experts_per_tok=2, + num_experts=4, + mlp_only_layers=[], + ).to_dict(), + vision_config=Qwen3VLMoeVisionConfig( + depth=4, + hidden_size=128, + initializer_range=0.02, + intermediate_size=256, + num_heads=8, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=896, + num_position_embeddings=576, + deepstack_visual_indexes=[1, 2, 3], + ).to_dict(), + ), + ) + +if GRANITE_AVAILABLE: + MINI_MODEL_SETUPS["mini_granite3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_granite, + liger_kernel_patch_revert_func=revert_liger_kernel_to_granite, + model_class=GraniteForCausalLM, + mini_model_config=GraniteConfig( + attention_bias=False, + attention_dropout=0.0, + # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=8192, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + pretraining_tp=1, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + logits_scaling=4.0, + # At rope backward + # 
Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if LLAVA_AVAILABLE: + # https://huggingface.co/llava-hf/llava-1.5-7b-hf + MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llava, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llava, + model_class=LlavaForConditionalGeneration, + mini_model_config=LlavaConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pretraining_tp=1, + tie_word_embeddings=False, + use_cache=True, + max_position_embeddings=4096, # llava-1.5-7b-hf + rms_norm_eps=1e-05, # llava-1.5-7b-hf + vocab_size=32064, # llava-1.5-7b-hf + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + vision_config=CLIPVisionConfig( + hidden_size=1024, + image_size=336, + intermediate_size=2048, # 4096 + model_type="clip_vision_model", + num_attention_heads=4, # 16 + num_hidden_layers=4, # 24 + patch_size=14, + projection_dim=768, + vocab_size=32000, + ), + vocab_size=32064, + ignore_index=-100, + pad_token_id=4, + image_token_index=3, + projector_hidden_act="gelu", + vision_feature_layer=-2, + vision_feature_select_strategy="default", + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if OLMO2_AVAILABLE: + MINI_MODEL_SETUPS["mini_olmo2"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_olmo2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_olmo2, + model_class=Olmo2ForCausalLM, + mini_model_config=Olmo2Config( + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if OLMO3_AVAILABLE: + MINI_MODEL_SETUPS["mini_olmo3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_olmo3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_olmo3, + model_class=Olmo3ForCausalLM, + mini_model_config=Olmo3Config( + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if GLM4_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4"] = 
MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4, + model_class=Glm4ForCausalLM, + mini_model_config=Glm4Config( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + partial_rotary_factor=0.5, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if GLM4V_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4v"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4v, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4v, + model_class=Glm4vForConditionalGeneration, + mini_model_config=Glm4vConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + image_token_id=151343, + video_token_id=151344, + image_start_token_id=151339, + image_end_token_id=151340, + video_start_token_id=151341, + video_end_token_id=151342, + partial_rotary_factor=0.5, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + text_config={ + "partial_rotary_factor": 0.5, + "hidden_act": "silu", + "hidden_size": 1024, + "intermediate_size": 2048, + "max_position_embeddings": 4096, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-5, + "vocab_size": 32000, + "attention_bias": True, + **( + {"rope_scaling": {"type": "default", "mrope_section": [8, 12, 12]}} + if not IS_TRANSFORMERS_V5_OR_LATER + else {} + ), + "pad_token_id": None, + }, + vision_config={ + "depth": 4, # 32 + "hidden_act": "silu", + "hidden_size": 128, # 1280 + "intermediate_size": 256, # 3420 + "num_heads": 16, + "in_chans": 3, + "out_hidden_size": 128, # 3584 + "patch_size": 14, + "spatial_merge_size": 2, + "temporal_patch_size": 2, + }, + ), + ) +if GLM4V_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4v_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4v_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4v_moe, + model_class=Glm4vMoeForConditionalGeneration, + mini_model_config=Glm4vMoeConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + image_token_id=151343, + video_token_id=151344, + image_start_token_id=151339, + image_end_token_id=151340, + video_start_token_id=151341, + video_end_token_id=151342, + partial_rotary_factor=0.5, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + 
attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + text_config={ + "partial_rotary_factor": 0.5, + "hidden_act": "silu", + "hidden_size": 1024, + "intermediate_size": 2048, + "max_position_embeddings": 4096, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-5, + "vocab_size": 32000, + "attention_bias": True, + "attention_dropout": 0.0, + "moe_intermediate_size": 1408, + "num_experts_per_tok": 2, + "n_shared_experts": 1, + "n_routed_experts": 8, + "routed_scaling_factor": 1.0, + "n_group": 1, + "topk_group": 1, + "first_k_dense_replace": 1, + "norm_topk_prob": True, + **( + {"rope_scaling": {"type": "default", "mrope_section": [8, 12, 12]}} + if not IS_TRANSFORMERS_V5_OR_LATER + else {} + ), + }, + vision_config={ + "depth": 4, # 32 + "hidden_act": "silu", + "hidden_size": 128, # 1280 + "intermediate_size": 256, # 3420 + "num_heads": 16, + "in_chans": 3, + "out_hidden_size": 128, # 3584 + "patch_size": 14, + "spatial_merge_size": 2, + "temporal_patch_size": 2, + }, + ), + ) + +if SMOLLM3_AVAILABLE: + MINI_MODEL_SETUPS["mini_smollm3"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_smollm3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_smollm3, + model_class=SmolLM3ForCausalLM, + mini_model_config=SmolLM3Config( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, # 128000 + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=8192, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + pretraining_tp=1, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if INTERNVL_AVAILABLE: + MINI_MODEL_SETUPS["mini_internvl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_internvl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_internvl, + model_class=InternVLForConditionalGeneration, + mini_model_config=InternVLConfig( + text_config=Qwen2Config( + rms_norm_eps=1e-5, + hidden_size=256, # 1024 + intermediate_size=1024, # 4096 + hidden_act="silu", + num_hidden_layers=4, # 24 + num_attention_heads=4, # 16 + num_key_value_heads=2, # 16 + max_position_embeddings=4096, # 8192 + vocab_size=32000, # 151936 + bos_token_id=1, + eos_token_id=2, + pad_token_id=2, + tie_word_embeddings=False, + ), + vision_config={ + "hidden_size": 256, # 1024 + "intermediate_size": 1024, # 4096 + "num_hidden_layers": 4, # 24 + "num_attention_heads": 4, # 16 + }, + image_token_id=10, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + +if FALCONH1_AVAILABLE: + MINI_MODEL_SETUPS["mini_falcon_h1"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_falcon_h1, + liger_kernel_patch_revert_func=revert_liger_kernel_to_falcon_h1, + model_class=FalconH1ForCausalLM, + mini_model_config=FalconH1Config( + model_type="falcon_h1", + vocab_size=32000, + hidden_size=256, # 4096 + num_hidden_layers=4, # 24 + num_attention_heads=4, # 32 + num_key_value_heads=2, # 8 + intermediate_size=1024, # 11008 + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + 
rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + mamba_d_ssm=128, # 1024 + mamba_n_heads=16, # 128 + mamba_d_state=32, # 245 + mamba_d_conv=2, # 4 + ), + ) + +if QWEN3NEXT_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_next"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_next, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_next, + model_class=Qwen3NextForCausalLM, + mini_model_config=Qwen3NextConfig( # Copypaste Qwen3MoeConfig + vocab_size=32000, + hidden_size=896, + intermediate_size=4864, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + attention_bias=False, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + decoder_sparse_step=1, + moe_intermediate_size=768, + num_experts_per_tok=2, + num_experts=8, + norm_topk_prob=False, + output_router_logits=False, + router_aux_loss_coef=0.001, + # config.dtype must be set if fla installed since there's a bug in the original code (No torch.get_current_dtype()) + # https://github.com/huggingface/transformers/blob/v4.57.1/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L613 + dtype=torch.float32, + ), + ) + +if QWEN3_5_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen3_5"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen3_5, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_5, + model_class=Qwen3_5ForCausalLM, + mini_model_config=Qwen3_5TextConfig( + vocab_size=32000, + hidden_size=896, + intermediate_size=4864, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + attention_bias=False, + attention_dropout=0.0, + head_dim=128, + linear_conv_kernel_dim=4, + linear_key_head_dim=64, + linear_value_head_dim=64, + linear_num_key_heads=8, + linear_num_value_heads=8, + layer_types=["linear_attention", "linear_attention", "linear_attention", "full_attention"], + dtype=torch.float32, + ), + ) + + +if HUNYUAN_V1_AVAILABLE: + MINI_MODEL_SETUPS["mini_hunyuan_v1"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_hunyuan_v1_dense, + liger_kernel_patch_revert_func=revert_liger_kernel_to_hunyuan_v1, + model_class=HunYuanDenseV1ForCausalLM, + mini_model_config=HunYuanDenseV1Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + num_hidden_layers=4, + hidden_size=896, + intermediate_size=4864, + num_attention_heads=8, + head_dim=112, + rms_norm_eps=1e-6, + tie_word_embeddings=True, + max_position_embeddings=32768, + initializer_range=0.02, + norm_eps=1e-6, + num_key_value_heads=2, + partial_rotary_factor=1.0, + vocab_size=32000, + use_cache=True, + attn_implementation="sdpa", + ), + ) + + MINI_MODEL_SETUPS["mini_hunyuan_v1_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_hunyuan_v1_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_hunyuan_v1_moe, + model_class=HunYuanMoEV1ForCausalLM, + mini_model_config=HunYuanMoEV1Config( + hidden_act="silu", + attention_dropout=0.0, + num_hidden_layers=4, + hidden_size=896, + intermediate_size=4864, + num_attention_heads=8, + head_dim=112, + rms_norm_eps=1e-6, + tie_word_embeddings=True, + 
max_position_embeddings=32768, + initializer_range=0.02, + norm_eps=1e-6, + num_key_value_heads=2, + partial_rotary_factor=1.0, + vocab_size=32000, + num_experts=8, + moe_topk=2, + use_cache=True, + attn_implementation="sdpa", + ), + ) + +if EXAONE4_AVAILABLE: + MINI_MODEL_SETUPS["mini_exaone4"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_exaone4, + liger_kernel_patch_revert_func=revert_liger_kernel_to_exaone4, + model_class=Exaone4ForCausalLM, + mini_model_config=Exaone4Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=True, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + pad_token_id=None, + ), + ) + + +def create_model(model_name="mini_llama3"): + """ + Create a mini version of the model. + The commented values are the original full-size values. + """ + model_config = MINI_MODEL_SETUPS[model_name].mini_model_config + model_class = MINI_MODEL_SETUPS[model_name].model_class + return model_class(model_config) + + +@require_deterministic +def run_mini_model( + model_name="mini_llama3", + num_steps=100, + dtype=torch.bfloat16, + lr=1e-5, + with_liger=False, +): + # If we moved the seeding below to the beginning of test_mini_model, the two runs would be initialized with different weights. + # This is due to the RNG (Random Number Generator). The formula of RNG progression is x_(n+1) = (a * x_n + c) % m. + # Every time the RNG is used, such as for randomly initializing weights, it progresses to the next state. + # Therefore, we have to reset the RNG before we create the model to ensure that weight initialization starts from the same RNG state. 
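+ # A minimal illustrative sketch of this (not executed here): + # set_seed(42); model_a = create_model(model_name) # RNG state S -> weights W + # set_seed(42); model_b = create_model(model_name) # RNG reset to S -> identical weights W + # Without the second set_seed, model_b would be built from a later RNG state and get different weights.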
+ + set_seed(42) + + revert_kwargs = {"model_config": MINI_MODEL_SETUPS[model_name]} + if "mllama" in model_name: + revert_kwargs["model_type"] = "causal_lm" + + if with_liger is True: + kwargs = { + "rope": True, + "rms_norm": True, + } + + if "glm4" in model_name or "llama4" in model_name or "qwen3_next" in model_name or "qwen3_5" in model_name: + kwargs["rope"] = False + + model_supports_layer_norm = "qwen2_vl" in model_name + if model_supports_layer_norm: + kwargs["layer_norm"] = True + + if "gemma" in model_name: + kwargs["geglu"] = True + else: + kwargs["swiglu"] = True + + if "llava" in model_name: + apply_liger_kernel_to_llama(**kwargs) + + kwargs["fused_linear_cross_entropy"] = False + kwargs["cross_entropy"] = False + + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) + else: + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) + + model = create_model(model_name).to(dtype).to(device) + + train_dataset = load_from_disk(DEFAULT_DATASET_PATH) + loader = DataLoader(train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn) + loader_iter = iter(loader) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr) + + loss_list = [] + + for i in range(num_steps): + batch = next(loader_iter).to(model.device) + optimizer.zero_grad() + output = model(**batch) + output.loss.backward() + optimizer.step() + print(f"Step {i}, Loss: {output.loss.item()}") + loss_list.append(output.loss.item()) + + topk_logprobs = get_topk(get_logprobs(output.logits)) + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) + return { + "loss": loss_list, + "topk_logprobs": topk_logprobs.values, + "model": model, + } + + +@pytest.mark.parametrize( + "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol", + [ + pytest.param( + "mini_llama4", # llama4 requires slightly larger tolerances to pass this test after bug fix to llama4 in transformers v5.0.0 + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-3, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not LLAMA4_AVAILABLE, + reason="Llama4 not available in this version of transformers", + ), + pytest.mark.xfail( + reason=( + "RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype:" + " float key.dtype: c10::BFloat16 and value.dtype: c10::BFloat16 instead." 
+ ) + ), + ], + ), + ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 5e-3, 1e-5, 5e-3, 1e-5), + pytest.param( + "mini_llava", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + ), + pytest.param( + "mini_mllama", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not MLLAMA_AVAILABLE, + reason="Mllama not available in this version of transformers", + ), + ), + pytest.param( + "mini_gemma3_text", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-4, + 5e-2, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not GEMMA3_AVAILABLE, + reason="Gemma3 not available in this version of transformers", + ), + ), + ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + pytest.param( + "mini_qwen3", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not QWEN3_AVAILABLE, + reason="Qwen3 not available in this version of transformers", + ), + ), + pytest.param( + "mini_qwen3_moe", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not QWEN3_AVAILABLE, + reason="Qwen3 not available in this version of transformers", + ), + ), + pytest.param( + "mini_qwen2_vl", + 32, + 1e-4, + torch.float32, + 1e-8, + 2e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", + ), + ), + pytest.param( + "mini_qwen2_5_vl", + 32, + 1e-4, + torch.float32, + 1e-8, + 2e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not QWEN2_5_VL_AVAILABLE, + reason="Qwen2.5-VL not available in this version of transformers", + ), + ), + pytest.param( + "mini_qwen3_vl", + 32, + 1e-4, + torch.float32, + 1e-8, + 2e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not QWEN3_VL_AVAILABLE, + reason="Qwen3-VL not available in this version of transformers", + ), + ), + pytest.param( + "mini_qwen3_vl_moe", + 32, + 1e-4, + torch.float32, + 1e-8, + 2e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not QWEN3_VL_MOE_AVAILABLE, + reason="Qwen3-VL-MoE not available in this version of transformers", + ), + pytest.mark.skipif( + True, + reason="Flaky test", + ), + ], + ), + pytest.param( + "mini_olmo2", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not OLMO2_AVAILABLE, + reason="OLMO2 not available in this version of transformers", + ), + ), + pytest.param( + "mini_olmo3", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not OLMO3_AVAILABLE, + reason="OLMO3 not available in this version of transformers", + ), + ), + pytest.param( + "mini_glm4", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not GLM4_AVAILABLE, + reason="Glm4 not available in this version of transformers", + ), + ), + pytest.param( + "mini_glm4v", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not GLM4V_AVAILABLE, + reason="Glm4v not available in this version of transformers", + ), + ), + pytest.param( + "mini_glm4v_moe", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not GLM4V_MOE_AVAILABLE, + 
reason="Glm4v_moe not available in this version of transformers", + ), + ), + ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + # TODO: mixtral is flaky so disable the test for now + # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5), + # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match + ("mini_gemma1", 32, 1e-5, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + pytest.param( + "mini_granite3", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-4, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not GRANITE_AVAILABLE, + reason="Granite not available in this version of transformers", + ), + ), + pytest.param( + "mini_smollm3", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=pytest.mark.skipif( + not SMOLLM3_AVAILABLE, + reason="Smollm3 not available in this version of transformers", + ), + ), + pytest.param( + "mini_internvl", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not INTERNVL_AVAILABLE, + reason="InternVL not available in this version of transformers", + ), + ), + pytest.param( + "mini_falcon_h1", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-4, + 4e-2, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not FALCONH1_AVAILABLE, + reason="FalconH1 not available in this version of transformers", + ), + ), + pytest.param( + "mini_qwen3_next", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not QWEN3NEXT_AVAILABLE, + reason="Qwen3Next not available in this version of transformers", + ), + pytest.mark.skip( + reason="flash-linear-attention's ChunkGatedDeltaRuleFunction does not support float32.\n" + + " Torch's implementation takes too long" + ), + ], + ), + pytest.param( + "mini_qwen3_5", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=[ + pytest.mark.skipif( + not QWEN3_5_AVAILABLE, + reason="Qwen3_5 not available in this version of transformers", + ), + pytest.mark.skip( + reason="flash-linear-attention's ChunkGatedDeltaRuleFunction does not support float32.\n" + + " Torch's implementation takes too long" + ), + ], + ), + pytest.param( + "mini_hunyuan_v1", + 32, + 1e-5, + torch.float32, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif( + not HUNYUAN_V1_AVAILABLE, + reason="Hunyuan_v1 not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_hunyuan_v1_moe", + 32, + 1e-5, + torch.float32, + 1e-2, + 5e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif( + not HUNYUAN_V1_AVAILABLE, + reason="Hunyuan_v1_moe not available in this version of transformers", + ), + ], + ), + pytest.param( + "mini_exaone4", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not EXAONE4_AVAILABLE, + reason="EXAONE4 not available in this version of transformers", + ), + ), + ], +) +def test_mini_model( + model_name, + num_steps, + lr, + dtype, + loss_atol, + loss_rtol, + logprobs_atol, + logprobs_rtol, + param_atol, + param_rtol, +): + # Non-liger models should be initialized and tested first to avoid the module being overridden + + 
expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr) + + actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True) + + # Compare every step of the loss + assert_verbose_allclose( + torch.tensor([expected_output["loss"]]), + torch.tensor([actual_output["loss"]]), + atol=loss_atol, + rtol=loss_rtol, + extra_info="[Loss]", + ) + + # No logits are materialized, so compare the top-k log-probabilities instead + assert_verbose_allclose( + expected_output["topk_logprobs"], + actual_output["topk_logprobs"], + atol=logprobs_atol, + rtol=logprobs_rtol, + extra_info="[Top K Logprobs]", + ) + + # Compare the parameters from the last step by iterating over both models' named parameters + for expected_param, actual_param in zip( + expected_output["model"].named_parameters(), + actual_output["model"].named_parameters(), + ): + assert_verbose_allclose( + expected_param[1], + actual_param[1], + atol=param_atol, + rtol=param_rtol, + extra_info="[Model parameters]", + ) diff --git a/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json b/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json new file mode 100755 index 0000000000000000000000000000000000000000..5c5c2adf67d113a9cc6896464cd3b98436b01fee --- /dev/null +++ b/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json @@ -0,0 +1,90 @@ +{ + "add_bos_token": true, + "add_eos_token": false, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "3": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "4": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "5": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "6": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "7": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "boi_token": "", + "bos_token": "", + "chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if
message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'model\n'}}\n{%- endif -%}\n", + "clean_up_tokenization_spaces": false, + "eoi_token": "", + "eos_token": "", + "extra_special_tokens": { + "boi_token": "", + "eoi_token": "", + "image_token": "" + }, + "image_token": "", + "model_max_length": 1000000000000000019884624838656, + "pad_token": "", + "processor_class": "Gemma3Processor", + "sp_model_kwargs": null, + "spaces_between_special_tokens": false, + "tokenizer_class": "GemmaTokenizer", + "unk_token": "", + "use_default_system_prompt": false +} \ No newline at end of file diff --git a/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json b/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json new file mode 100755 index 0000000000000000000000000000000000000000..6c7d5eec91a94408618dc343d1f1f10abff0bbe5 --- /dev/null +++ b/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json @@ -0,0 +1,61 @@ +{ + "add_bos_token": true, + "add_eos_token": false, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "3": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "4": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "" + ], + "bos_token": "", + "clean_up_tokenization_spaces": false, + "eos_token": "", + "model_max_length": 1000000000000000019884624838656, + "pad_token": "", + "processor_class": "PaliGemmaProcessor", + "sp_model_kwargs": {}, + "spaces_between_special_tokens": false, + "tokenizer_class": "GemmaTokenizer", + "unk_token": "", + "use_default_system_prompt": false, + "chat_template": "{% for message in messages %}{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' %}{{ '' }}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{ '<|eot_id|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>' }}{% endif %}" +} \ No newline at end of file diff --git a/test/resources/fake_configs/HuggingFaceTB/SmolVLM2-256M-Video-Instruct/tokenizer_config.json b/test/resources/fake_configs/HuggingFaceTB/SmolVLM2-256M-Video-Instruct/tokenizer_config.json new file mode 100755 index 0000000000000000000000000000000000000000..e4042a1126290fdb96ece2a4ad8dd7108c5de484 --- /dev/null +++ b/test/resources/fake_configs/HuggingFaceTB/SmolVLM2-256M-Video-Instruct/tokenizer_config.json 
@@ -0,0 +1,1192 @@ +{ + "add_prefix_space": false, + "added_tokens_decoder": { + "0": { + "content": "<|endoftext|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "<|im_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "<|im_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "3": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "4": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "5": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "6": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "7": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "8": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "9": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "10": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "11": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "12": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "13": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "14": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "15": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "16": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49152": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49153": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49154": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49155": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49156": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49157": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49158": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49159": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49160": { + "content": "", + "lstrip": false, + 
"normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49161": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49162": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49163": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49164": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49165": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49166": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49167": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49168": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49169": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49170": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49171": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49172": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49173": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49174": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49175": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49176": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49177": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49178": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49179": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49180": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49181": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49182": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49183": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49184": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49185": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49186": { + 
"content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49187": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49188": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49189": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49190": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49191": { + "content": "<|reserved_special_token_0|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49192": { + "content": "<|reserved_special_token_1|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49193": { + "content": "<|reserved_special_token_2|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49194": { + "content": "<|reserved_special_token_3|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49195": { + "content": "<|reserved_special_token_4|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49196": { + "content": "<|reserved_special_token_5|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49197": { + "content": "<|reserved_special_token_6|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49198": { + "content": "<|reserved_special_token_7|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49199": { + "content": "<|reserved_special_token_8|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49200": { + "content": "<|reserved_special_token_9|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49201": { + "content": "<|reserved_special_token_10|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49202": { + "content": "<|reserved_special_token_11|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49203": { + "content": "<|reserved_special_token_12|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49204": { + "content": "<|reserved_special_token_13|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49205": { + "content": "<|reserved_special_token_14|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49206": { + "content": "<|reserved_special_token_15|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49207": { + "content": "<|reserved_special_token_16|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49208": { + "content": 
"<|reserved_special_token_17|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49209": { + "content": "<|reserved_special_token_18|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49210": { + "content": "<|reserved_special_token_19|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49211": { + "content": "<|reserved_special_token_20|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49212": { + "content": "<|reserved_special_token_21|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49213": { + "content": "<|reserved_special_token_22|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49214": { + "content": "<|reserved_special_token_23|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49215": { + "content": "<|reserved_special_token_24|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49216": { + "content": "<|reserved_special_token_25|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49217": { + "content": "<|reserved_special_token_26|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49218": { + "content": "<|reserved_special_token_27|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49219": { + "content": "<|reserved_special_token_28|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49220": { + "content": "<|reserved_special_token_29|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49221": { + "content": "<|reserved_special_token_30|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49222": { + "content": "<|reserved_special_token_31|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49223": { + "content": "<|reserved_special_token_32|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49224": { + "content": "<|reserved_special_token_33|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49225": { + "content": "<|reserved_special_token_34|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49226": { + "content": "<|reserved_special_token_35|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49227": { + "content": "<|reserved_special_token_36|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49228": { + "content": "<|reserved_special_token_37|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49229": { + "content": "<|reserved_special_token_38|>", + "lstrip": 
false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49230": { + "content": "<|reserved_special_token_39|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49231": { + "content": "<|reserved_special_token_40|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49232": { + "content": "<|reserved_special_token_41|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49233": { + "content": "<|reserved_special_token_42|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49234": { + "content": "<|reserved_special_token_43|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49235": { + "content": "<|reserved_special_token_44|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49236": { + "content": "<|reserved_special_token_45|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49237": { + "content": "<|reserved_special_token_46|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49238": { + "content": "<|reserved_special_token_47|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49239": { + "content": "<|reserved_special_token_48|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49240": { + "content": "<|reserved_special_token_49|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49241": { + "content": "<|reserved_special_token_50|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49242": { + "content": "<|reserved_special_token_51|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49243": { + "content": "<|reserved_special_token_52|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49244": { + "content": "<|reserved_special_token_53|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49245": { + "content": "<|reserved_special_token_54|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49246": { + "content": "<|reserved_special_token_55|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49247": { + "content": "<|reserved_special_token_56|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49248": { + "content": "<|reserved_special_token_57|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49249": { + "content": "<|reserved_special_token_58|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49250": { + "content": "<|reserved_special_token_59|>", + "lstrip": false, + "normalized": false, + "rstrip": 
false, + "single_word": false, + "special": true + }, + "49251": { + "content": "<|reserved_special_token_60|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49252": { + "content": "<|reserved_special_token_61|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49253": { + "content": "<|reserved_special_token_62|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49254": { + "content": "<|reserved_special_token_63|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49255": { + "content": "<|reserved_special_token_64|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49256": { + "content": "<|reserved_special_token_65|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49257": { + "content": "<|reserved_special_token_66|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49258": { + "content": "<|reserved_special_token_67|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49259": { + "content": "<|reserved_special_token_68|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49260": { + "content": "<|reserved_special_token_69|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49261": { + "content": "<|reserved_special_token_70|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49262": { + "content": "<|reserved_special_token_71|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49263": { + "content": "<|reserved_special_token_72|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49264": { + "content": "<|reserved_special_token_73|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49265": { + "content": "<|reserved_special_token_74|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49266": { + "content": "<|reserved_special_token_75|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49267": { + "content": "<|reserved_special_token_76|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49268": { + "content": "<|reserved_special_token_77|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49269": { + "content": "<|reserved_special_token_78|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49270": { + "content": "<|reserved_special_token_79|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49271": { + "content": "<|reserved_special_token_80|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": 
true + }, + "49272": { + "content": "<|reserved_special_token_81|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49273": { + "content": "<|reserved_special_token_82|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49274": { + "content": "<|reserved_special_token_83|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49275": { + "content": "<|reserved_special_token_84|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49276": { + "content": "<|reserved_special_token_85|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49277": { + "content": "<|reserved_special_token_86|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49278": { + "content": "<|reserved_special_token_87|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "49279": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "", + "", + "" + ], + "bos_token": "<|im_start|>", + "chat_template": "<|im_start|>{% for message in messages %}{{message['role'] | capitalize}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '' }}{% endif %}{% endfor %}\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}", + "clean_up_tokenization_spaces": false, + "end_of_utterance_token": "", + "eos_token": "", + "extra_special_tokens": { + "end_of_utterance_token": "", + "fake_image_token": "", + "global_image_token": "", + "image_token": "" + }, + "fake_image_token": "", + "global_image_token": "", + "image_token": "", + "legacy": false, + "model_max_length": 8192, + "pad_token": "<|im_end|>", + "processor_class": "SmolVLMProcessor", + "tokenizer_class": "GPT2Tokenizer", + "truncation_side": "left", + "unk_token": "<|endoftext|>", + "vocab_size": 49152 +} diff --git a/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json new file mode 100755 index 0000000000000000000000000000000000000000..c32625c74fdedbde4c654d205c66f6b3dc852454 --- /dev/null +++ b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json @@ -0,0 +1,28 @@ +{ + "crop_size": { + "height": 336, + "width": 336 + }, + "do_center_crop": true, + "do_convert_rgb": true, + "do_normalize": true, + "do_rescale": true, + "do_resize": true, + "image_mean": [ + 0.48145466, + 0.4578275, + 0.40821073 + ], + "image_processor_type": "CLIPImageProcessor", + "image_std": [ + 0.26862954, + 0.26130258, + 0.27577711 + ], + "processor_class": "LlavaProcessor", + "resample": 3, + "rescale_factor": 0.00392156862745098, + "size": { + "shortest_edge": 336 + } +} \ No newline at end of file diff --git a/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json new file mode 100755 index 0000000000000000000000000000000000000000..8fbb221c7fdc95258d63f57d1e33aed4633068f6 --- /dev/null +++ 
b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json @@ -0,0 +1,7 @@ +{ + "image_token": "", + "num_additional_image_tokens": 1, + "patch_size": 14, + "processor_class": "LlavaProcessor", + "vision_feature_select_strategy": "default" +} \ No newline at end of file diff --git a/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json new file mode 100755 index 0000000000000000000000000000000000000000..f9c6572a84cc8de54b5807d28398daf1bc106dbf --- /dev/null +++ b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json @@ -0,0 +1,66 @@ +{ + "add_bos_token": true, + "add_eos_token": false, + "add_prefix_space": null, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "3": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "4": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "bos_token": "", + "clean_up_tokenization_spaces": false, + "eos_token": "", + "extra_special_tokens": { + "image_token": "" + }, + "image_token": "", + "legacy": false, + "chat_template": "{% if not add_generation_prompt is defined %}{% set add_last_empty_assistant = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message.role == 'user' %}{{ '### User:\n' }}{% if message.content is not string %}{% for content in message.content %}{% if content.type == 'image' %}{{ '' }}{% elif content.type == 'text' %}{{ content.text }}{% else %}{# Do nothing #}{% endif %}{% endfor %}{% else %}{{ message.content }}{% endif %}{{ '\n\n' }}{% elif message.role == 'system' %}{{ '### System:\n' }}{% if message.content is not string %}{% for content in message.content %}{% if content.type == 'image' %}{{ '' }}{% elif content.type == 'text' %}{{ content.text }}{% else %}{# Do nothing #}{% endif %}{% endfor %}{% else %}{{ message.content }}{% endif %}{{ '\n\n' }}{% elif message.role == 'assistant' %}{{ '### Assistant:\n' }}{% if message.content is not string %}{% for content in message.content %}{% if content.type == 'text' %}{{ content.text }}{% else %}{# Do nothing #}{% endif %}{% endfor %}{% else %}{{ message.content }}{% endif %}{% else %}{{ '' }}{% endif %}{% endfor %}{% if not add_generation_prompt %}{{ eos_token }}{% elif add_generation_prompt %}{{ '### Assistant:\n' }}{% else %}{# Do nothing #}{% endif %}", + "model_max_length": 1000000000000000019884624838656, + "pad_token": "", + "padding_side": "left", + "processor_class": "LlavaProcessor", + "sp_model_kwargs": {}, + "tokenizer_class": "LlamaTokenizer", + "trust_remote_code": false, + "unk_token": "", + "use_default_system_prompt": false, + "return_token_type_ids": false +} \ No newline at end of file diff --git a/test/resources/fake_configs/OpenGVLab/InternVL3-1B-hf/tokenizer_config.json b/test/resources/fake_configs/OpenGVLab/InternVL3-1B-hf/tokenizer_config.json new file mode 100755 index 0000000000000000000000000000000000000000..f47a164c28afd23a0a1994b7ebd8afa7b6f33c32 --- /dev/null +++ 
b/test/resources/fake_configs/OpenGVLab/InternVL3-1B-hf/tokenizer_config.json @@ -0,0 +1,307 @@ +{ + "add_bos_token": false, + "add_prefix_space": false, + "added_tokens_decoder": { + "151643": { + "content": "<|endoftext|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151644": { + "content": "<|im_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151645": { + "content": "<|im_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151646": { + "content": "<|object_ref_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151647": { + "content": "<|object_ref_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151648": { + "content": "<|box_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151649": { + "content": "<|box_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151650": { + "content": "<|quad_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151651": { + "content": "<|quad_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151652": { + "content": "<|vision_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151653": { + "content": "<|vision_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151654": { + "content": "<|vision_pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151655": { + "content": "<|image_pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151656": { + "content": "<|video_pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151657": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151658": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151659": { + "content": "<|fim_prefix|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151660": { + "content": "<|fim_middle|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151661": { + "content": "<|fim_suffix|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151662": { + "content": "<|fim_pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151663": { + "content": "<|repo_name|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151664": { + "content": "<|file_sep|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151665": { + "content": "", + 
"lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151666": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151667": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151668": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151669": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151670": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151671": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151672": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151673": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151674": { + "content": "