{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Creating a new press\n", "\n", "This guide walks you through the process of creating a new press." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from dataclasses import dataclass\n", "from contextlib import contextmanager\n", "\n", "import torch\n", "from torch import nn\n", "from transformers import pipeline\n", "\n", "from kvpress import BasePress, KnormPress, ScorerPress" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n", "Device set to use cuda:0\n" ] } ], "source": [ "# Load pipeline\n", "\n", "device = \"cuda:0\"\n", "ckpt = \"Qwen/Qwen2.5-1.5B-Instruct\"\n", "attn_implementation = \"flash_attention_2\"\n", "pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, dtype=\"auto\", model_kwargs={\"attn_implementation\":attn_implementation})" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Load data\n", "\n", "context = \"In this step-by-step guide, you will learn how to create a new press in kvpress !\"\n", "question = \"\\nWhat is the purpose of this guide?\"\n", "tokens = pipe.tokenizer(context, return_tensors=\"pt\").to(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Understanding how presses work" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A press registers a forward hook on each attention layer during the pre-filling phase. The hook is called immediately after the layer's forward pass and compresses the KV cache."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n", "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Cache shape w/o press: torch.Size([1, 2, 20, 128])\n", "Cache shape w/ press: torch.Size([1, 2, 15, 128])\n", "\n", "The purpose of this step-by-step guide is to provide instructions on how to create a new press in kvpress. The guide is designed to help users understand the process of setting up a new press in the kvpress platform.\n" ] } ], "source": [ "compression_ratio = 0.25\n", "press = KnormPress(compression_ratio)\n", "\n", "with torch.no_grad():\n", " outputs_without_press = pipe.model(**tokens, output_hidden_states=True)\n", "\n", "with torch.no_grad(), press(pipe.model):\n", " output_with_press = pipe.model(**tokens)\n", "\n", "print(f\"Cache shape w/o press: {outputs_without_press.past_key_values[0][0].shape}\")\n", "print(f\"Cache shape w/ press: {output_with_press.past_key_values[0][0].shape}\\n\")\n", "\n", "# The `KVPressTextGenerationPipeline` simply applies the `press` as above on the context tokens (see `_forward` method for more details).\n", "print(pipe(context, question=question, press=press)[\"answer\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 
Creating your own press\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1 Updating the `score` method" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The easiest way to create a new press is to write a class that inherits from `ScorerPress` and implements a `score` method that computes the score for each key-value pair.\n", "\n", "The arguments of the `score` method are obtained from the forward hook:\n", "- `module`: the attention layer\n", "- `hidden_states`: the input of the attention layer\n", "- `keys` and `values`: the key-value pairs from the attention layer\n", "- `attentions`: the attention weights, only available with `attn_implementation=\"eager\"`\n", "\n", "In this first example, we will reproduce the `KnormPress`, where the score of a key-value pair is simply the negative of the norm of the key vector." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class MyKnormPress(ScorerPress):\n", "    def score(\n", "        self,\n", "        module: nn.Module,\n", "        hidden_states: torch.Tensor,\n", "        keys: torch.Tensor,\n", "        values: torch.Tensor,\n", "        attentions: torch.Tensor,\n", "        kwargs: dict,\n", "    ) -> torch.Tensor:\n", "\n", "        scores = -keys.norm(dim=-1)\n", "\n", "        # For demonstration, we show some details on the shape for the first layer\n", "        if module.layer_idx == 0:\n", "            print(f\"module: {module}\")\n", "            print(f\"Number of key value heads: {module.config.num_key_value_heads}\")\n", "            print(f\"Sequence length: {hidden_states.shape[1]}\")\n", "            print()\n", "            print(f\"hidden_states shape: {hidden_states.shape}\")\n", "            print(f\"keys shape: {keys.shape}\") # shape (bhnd)\n", "            print(f\"values shape: {values.shape}\") # shape (bhnd)\n", "            print(f\"score shape: {scores.shape}\") # shape (bhn)\n", "            print()\n", "\n", "        return scores\n", "\n", "\n", "press = MyKnormPress(compression_ratio)\n", "print(pipe(context, question=question, press=press)[\"answer\"])" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "### 2.2 Updating the `compress` method" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `compress` method defined in `BasePress` contains the core logic of the compression and returns the compressed keys and values. For instance, `ScorerPress.compress` calls the `score` method (which is specific to `ScorerPress`) and prunes the key-value pairs based on the scores.\n", "\n", "The following example shows how it works. We will re-implement the `StreamingLLMPress` in a more compact way." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class MyStreamingLLMPress(BasePress):\n", "    n_first: int = 1\n", "    n_last: int = 8\n", "\n", "    def compress(\n", "        self,\n", "        module: nn.Module,\n", "        hidden_states: torch.Tensor,\n", "        keys: torch.Tensor,\n", "        values: torch.Tensor,\n", "        attentions: torch.Tensor,\n", "        kwargs: dict,\n", "    ) -> tuple[torch.Tensor, torch.Tensor]:\n", "\n", "        # Keep the first n_first and the last n_last tokens, drop everything in between\n", "        mask = torch.ones(keys.shape[-2], dtype=torch.bool, device=keys.device)\n", "        mask[self.n_first : -self.n_last] = False\n", "        return keys[:, :, mask, :], values[:, :, mask, :]\n", "\n", "\n", "for n_last in [2, 4, 8]:\n", "    press = MyStreamingLLMPress(n_last=n_last)\n", "    print(f\"\\nn_last: {n_last}\")\n", "    print(f\"Last tokens seen by the model: {pipe.tokenizer.decode(tokens.input_ids[0, -n_last:])}\")\n", "    print(f\"Answer: {pipe(context, question=question, press=press)['answer']}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the `compress` method is itself called from the `forward_hook` method, which ensures that quantization is handled properly and that compression is only performed during prefilling. We don't recommend changing the `forward_hook` method directly, but you can still modify it if you need to."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.3 Head-wise compression\n", "\n", "Since 0.2.0, kvpress supports head-wise compression, where the KV cache of each head can be compressed with a different compression ratio.\n", "\n", "Proper head-wise compression would require a new attention kernel along with a custom cache class. Instead, the current implementation emulates head-wise compression by replacing the pruned keys with a fake key so that they do not affect the output of the attention layer. This is implemented through `kvpress.attention_patch.patch_attention_functions`.\n", "\n", "To implement a method that compresses the KV cache head-wise, set `module.masked_key_indices` as outlined below." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "compression_ratio: 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Answer: The purpose of this step-by-step guide is to provide a comprehensive and easy-to-follow tutorial on how to create a new press in the KVPress platform. The guide is designed to help users understand the process of setting up a new press, including the\n", "\n", "compression_ratio: 0.25\n", "Answer: The purpose of this guide is to provide a step-by-step process for creating a new press in KVPRESS, which is a popular open-source web server. The guide will cover the necessary steps to set up and configure a new press, including installing\n", "\n", "compression_ratio: 0.9\n", "Answer: This guide is not a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. 
It is a\n" ] } ], "source": [ "@dataclass\n", "class RandomHeadPress(BasePress):\n", "\n", "    compression_ratio: float = 0.0\n", "\n", "    def compress(self, module, hidden_states, keys, values, attentions, kwargs):\n", "        assert keys.shape[0] == 1, \"Only batch size 1 is supported\"\n", "        # Draw one random score per (batch, head, token) and mask the lowest-scoring fraction\n", "        scores = torch.rand(keys.shape[:-1], device=keys.device)\n", "        mask = scores < torch.quantile(scores, self.compression_ratio)\n", "        module.masked_key_indices = torch.nonzero(mask, as_tuple=True)\n", "\n", "        return keys, values\n", "\n", "for compression_ratio in [0, 0.25, 0.9]:\n", "    press = RandomHeadPress(compression_ratio)\n", "    print(f\"\\ncompression_ratio: {compression_ratio}\")\n", "    print(f\"Answer: {pipe(context, question=question, press=press)['answer']}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Contributing to kvpress" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "All presses should be stored in the `presses` directory. Before opening a pull request with your new press, make sure to:\n", "- register it in the `__init__.py` file of the repository\n", "- register the press in [default_presses.py](tests/default_presses.py)\n", "- update the README" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 2 }