"...text-generation-inference.git" did not exist on "deb440b3a2179b1eccce9cf5dc1d4ff0e8a03135"
Unverified Commit ebd0e1c1 authored by Glen Liu's avatar Glen Liu Committed by GitHub
Browse files

[doc] add walkthrough for implementing and hosting a simple llama wrapper m… (#10093)

## Example: Implementing and Serving a Llama Wrapper Model
Below is an introductory, step-by-step walkthrough on how to implement a new model end-to-end in SGLang and then run it via the [Offline Engine](https://github.com/sgl-project/sglang/blob/main/docs/basic_usage/offline_engine_api.ipynb).
### Implementing Our Model
To keep things simple, this new model will be a thin wrapper around [Llama 3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct), and our goal will simply be to bias the output logits on each `forward` call by taking the square root of every positive logit (non-positive logits are left unchanged).
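To make the transformation concrete, here is a tiny, self-contained illustration of the biasing rule on a toy logit vector (the numbers are made up purely for illustration; in the real model the logits come from the `LogitsProcessor`):
```python
import torch

# Toy logits, purely for illustration: positive values get square-rooted,
# non-positive values pass through unchanged.
logits = torch.tensor([4.0, 0.25, -1.0, 9.0])
biased = torch.where(logits > 0, logits.sqrt(), logits)
print(biased)  # tensor([2.0000, 0.5000, -1.0000, 3.0000])
```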
Let's start by defining our model in a file called `llama_wrapper.py`.
The first step is to import the necessary libraries from SRT, which is SGLang's internal backend.
```python
# In the file `llama_wrapper.py`
import torch
from transformers import LlamaConfig
from typing import Optional
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.llama import LlamaForCausalLM
```
Next, we declare a new `class` for our model and have it inherit from `LlamaForCausalLM`, which allows our model to access `LlamaForCausalLM`'s predefined modules and layers, such as `LlamaAttention` and `LlamaMLP`.
Note that almost all model implementations take in `config` and `quant_config` as arguments for their `__init__` method; `config` and `quant_config` are passed in via [`model_loader/loader.py`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_loader/loader.py#L219).
Because we have inherited from `LlamaForCausalLM`, we can pass our parameters directly to its constructor, which will set the member variables for us.
```python
class LlamaWrapper(LlamaForCausalLM):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
```
Now, we want to define the `forward` method, which is what will be called at inference time.
Note that the signature for `forward` is essentially the same for any model; you can take a look at the other models defined in the [`models` directory](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/) for reference.
To see where exactly `forward` is called in the SGLang runtime's internals, take a look at [`forward_decode`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1705) and [`forward_extend`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1724) in the [`ModelRunner` class](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/model_runner.py).
```python
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
```
We now invoke `self.model` (a member variable that `LlamaForCausalLM` creates in its `__init__` method); calling it runs the underlying `LlamaModel`'s `forward` pass, just as `LlamaForCausalLM.forward` does.
After that, we feed the resulting `hidden_states` into our model's `LogitsProcessor` (also instantiated in `LlamaForCausalLM`'s `__init__`).
```python
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )
        res: LogitsProcessorOutput = self.logits_processor(
            input_ids,
            hidden_states,
            self.lm_head,
            forward_batch,
        )
```
After receiving the logits for the next token, we can finally perform our biasing step.
```python
        # Take the square root of each positive logit; leave the rest unchanged.
        orig_logits = res.next_token_logits
        res.next_token_logits = torch.where(
            orig_logits > 0,
            orig_logits.sqrt(),
            orig_logits,
        )
        return res
```
Now, our `LlamaWrapper` model is created and ready to be served!
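For reference, here is the complete `llama_wrapper.py`, assembled from the snippets above:
```python
# llama_wrapper.py (full file, assembled from the snippets above)
import torch
from transformers import LlamaConfig
from typing import Optional

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.llama import LlamaForCausalLM


class LlamaWrapper(LlamaForCausalLM):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config, quant_config=quant_config, prefix=prefix)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
        # Run the underlying LlamaModel to get the hidden states.
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )
        # Compute next-token logits from the hidden states.
        res: LogitsProcessorOutput = self.logits_processor(
            input_ids,
            hidden_states,
            self.lm_head,
            forward_batch,
        )
        # Bias the logits: square root of each positive logit.
        orig_logits = res.next_token_logits
        res.next_token_logits = torch.where(
            orig_logits > 0,
            orig_logits.sqrt(),
            orig_logits,
        )
        return res
```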
### Serving Our Model Via SGLang's Offline Engine
The next step of this walkthrough is to host our new model with the Offline Engine, so that it can be served locally without an HTTP server.
First, create a new file called `run.py`.
Now, we must ensure that SGLang's `ModelRegistry` can find our model.
To do this, we first download the model's configuration and weights from Hugging Face.
```python
# In the file `run.py`
import asyncio
from functools import lru_cache

from huggingface_hub import snapshot_download

from llama_wrapper import LlamaWrapper  # Make sure to import our new model!

import sglang as sgl
from sglang.srt.models.registry import ModelRegistry

# Make sure to request access to this model on Hugging Face, then export your
# `HF_TOKEN` to download the model snapshot.
llama_dir = snapshot_download(
    repo_id="meta-llama/Llama-3.1-8B-Instruct",
    local_dir="./llama_ckpt",
)
```
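If exporting `HF_TOKEN` in your shell is inconvenient, one alternative (a small sketch; the token string below is a placeholder) is to authenticate programmatically with `huggingface_hub` before calling `snapshot_download`:
```python
# Alternative to exporting HF_TOKEN: log in programmatically before the
# snapshot_download call above. The token below is a placeholder.
from huggingface_hub import login

login(token="hf_...")  # your Hugging Face access token
```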
Now that we have the model checkpoint on disk, we point it to `LlamaWrapper` by changing the `architectures` field in `./llama_ckpt/config.json`.
That way, when we pass the path of our checkpoint to SGLang, it will know to use "LlamaWrapper" instead of "LlamaForCausalLM" as the model architecture.
```json
{
  "architectures": [
    "LlamaWrapper"
  ],
  ...
}
```
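If you would rather not edit the file by hand, a short sketch like the following (standard library only; it assumes the checkpoint was downloaded to `./llama_ckpt` as above) makes the same change programmatically:
```python
# Optional: patch config.json programmatically instead of editing it by hand.
import json
from pathlib import Path

config_path = Path("./llama_ckpt/config.json")
config = json.loads(config_path.read_text())
config["architectures"] = ["LlamaWrapper"]  # was ["LlamaForCausalLM"]
config_path.write_text(json.dumps(config, indent=2))
```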
However, if we don't link our `LlamaWrapper` class to the "LlamaWrapper" registry keyword, then SGLang won't be able to find our model.
Thus, to register our `LlamaWrapper`, we want to follow the steps in the above section titled "Registering an External Model Implementation".
```python
@lru_cache()
def import_new_model_classes():
    model_arch_name_to_cls = {"LlamaWrapper": LlamaWrapper}
    return model_arch_name_to_cls


ModelRegistry.models.update(import_new_model_classes())
```
Lastly, when we create our `Engine`, we just pass in the path to the local model directory.
Then, our `LlamaWrapper` is ready to be served; for this walkthrough, we will use SGLang `Engine`'s non-streaming asynchronous generation endpoint.
```python
def main():
    llm = sgl.Engine(model_path="./llama_ckpt")
    sampling_params = {"temperature": 0.2, "top_k": 5}
    prompts = [
        "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
        "Provide a concise factual statement about France’s capital city. The capital of France is",
        "Explain possible future trends in artificial intelligence. The future of AI is",
    ]
    asyncio.run(run_llm(llm, sampling_params, prompts))
    llm.shutdown()


async def run_llm(
    llm,
    sampling_params,
    prompts,
) -> None:
    outputs = await llm.async_generate(prompts, sampling_params)
    for prompt, output in zip(prompts, outputs):
        print(f"\nPrompt: {prompt}")
        print(f"Generated text: {output['text']}")


if __name__ == "__main__":
    main()
```
Now, when we run `python run.py`, we will see the outputs of our newly created model!
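As a side note, if you prefer a blocking call over `asyncio`, `sgl.Engine` also exposes a synchronous `generate` method. Below is a minimal sketch of how the asynchronous portion of `run.py` could be replaced; it assumes the imports and the model-registration code above have already run, so that "LlamaWrapper" is resolvable.
```python
# Synchronous variant of the generation step in run.py. Assumes the
# registration code above has already executed in the same script.
import sglang as sgl  # already imported at the top of run.py

llm = sgl.Engine(model_path="./llama_ckpt")
sampling_params = {"temperature": 0.2, "top_k": 5}
prompts = ["Provide a concise factual statement about France’s capital city. The capital of France is"]

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print(f"\nPrompt: {prompt}")
    print(f"Generated text: {output['text']}")
llm.shutdown()
```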
## Documentation
Add to table of supported models in [generative_models.md](https://github.com/sgl-project/sglang/blob/main/docs/supported_models/generative_models.md) or [multimodal_language_models.md](https://github.com/sgl-project/sglang/blob/main/docs/supported_models/multimodal_language_models.md)