This folder contains instructions to fine-tune Meta Llama 3 on a
* [single-GPU setup](./singlegpu_finetuning.md)
* [multi-GPU setup](./multigpu_finetuning.md)
using the canonical [finetuning script](../../src/llama_recipes/finetuning.py) in the llama-recipes package.
If you are new to fine-tuning techniques, check out [an overview](./LLM_finetuning_overview.md).
> [!TIP]
> If you want to try finetuning Meta Llama 3 with Huggingface's trainer, here is a Jupyter notebook with an [example](./huggingface_trainer/peft_finetuning.ipynb)
## How to configure finetuning settings?
> [!TIP]
> All the settings defined in the [config files](../../src/llama_recipes/configs/) can be passed as args through the CLI when running the script; there is no need to change the config files directly.
* [Training config file](../../src/llama_recipes/configs/training.py) is the main config file that specifies the settings for our run; it can be found in the [configs folder](../../src/llama_recipes/configs/).
It lets us specify the training settings for everything from `model_name` to `dataset_name`, `batch_size` and so on. Below is the list of supported settings:
```python
peft_method: str = "lora"  # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
use_peft: bool = False
from_peft_checkpoint: str = ""  # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
output_dir: str = "PATH/to/save/PEFT/model"
freeze_layers: bool = False
num_freeze_layers: int = 1
quantization: bool = False
one_gpu: bool = False
save_model: bool = True
dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"  # will be used if using FSDP
dist_checkpoint_folder: str = "fine-tuned"  # will be used if using FSDP
save_optimizer: bool = False  # will be used if using FSDP
use_fast_kernels: bool = False  # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformer memory-efficient kernels
use_wandb: bool = False  # Enable wandb for experiment tracking
save_metrics: bool = False  # saves training metrics to a json file for later plotting
flop_counter: bool = False  # Enable flop counter to measure model throughput; cannot be used with the PyTorch profiler at the same time.
flop_counter_start: int = 3  # The step to start counting FLOPS; default is 3, i.e. the counter starts after 3 warmup steps.
use_profiler: bool = False  # Enable PyTorch profiler; cannot be used with the flop counter at the same time.
profiler_dir: str = "PATH/to/save/profiler/results"  # will be used if using profiler
```
* [Datasets config file](../../src/llama_recipes/configs/datasets.py) provides the available options for datasets.
* [peft config file](../../src/llama_recipes/configs/peft.py) provides the supported PEFT methods and the respective settings that can be modified. We currently support LoRA and Llama-Adapter. Please note that LoRA is the only technique supported in combination with FSDP.
* [FSDP config file](../../src/llama_recipes/configs/fsdp.py) provides FSDP settings such as the following (a CLI example follows this list):
    * `mixed_precision` boolean flag to specify using mixed precision, defaults to true.
    * `use_fp16` boolean flag to specify using FP16 for mixed precision, defaults to False. We recommend not setting this flag and only setting `mixed_precision`, which will use `BF16`; this helps with speed and memory savings while avoiding the challenges of scaler accuracy with `FP16`.
    * `sharding_strategy` specifies the sharding strategy for FSDP; it can be:
        * `FULL_SHARD` shards model parameters, gradients and optimizer states, and results in the most memory savings.
        * `SHARD_GRAD_OP` shards gradients and optimizer states and keeps the parameters after the first `all_gather`. This reduces communication overhead, which is especially beneficial with slower networks, more specifically in multi-node cases. It comes with the trade-off of higher memory consumption.
        * `NO_SHARD` is equivalent to DDP; it does not shard model parameters, gradients or optimizer states. It keeps the full parameters after the first `all_gather`.
        * `HYBRID_SHARD` is available in PyTorch Nightlies. It applies FSDP within a node and DDP between nodes. It is intended for multi-node cases and is helpful for slower networks, given that your model fits into one node.
    * `checkpoint_type` specifies the state dict checkpoint type for saving the model. `FULL_STATE_DICT` streams the state_dict of each model shard from a rank to CPU and assembles the full state_dict on CPU. `SHARDED_STATE_DICT` saves one checkpoint per rank and enables re-loading the model in a different world size.
    * `fsdp_activation_checkpointing` enables activation checkpointing for FSDP; this saves a significant amount of memory with the trade-off of recomputing intermediate activations during the backward pass. The saved memory can be re-invested in higher batch sizes to increase throughput. We recommend using this option.
    * `pure_bf16` moves the model to `BFloat16`, and if `optimizer` is set to `anyprecision` the optimizer states will be kept in `BFloat16` as well. You can use this option if necessary.
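For example, a minimal sketch of overriding these FSDP settings from the command line (assuming the `--fsdp_config.<field>` override syntax; paths and GPU counts are illustrative):

```bash
torchrun --nnodes 1 --nproc_per_node 8 recipes/finetuning/finetuning.py \
    --enable_fsdp --fsdp_config.pure_bf16 \
    --model_name /path_of_model_folder/8B \
    --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
```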
## Weights & Biases Experiment Tracking
You can enable [W&B](https://wandb.ai/) experiment tracking by using the `use_wandb` flag as below. You can change the project name, entity and other `wandb.init` arguments in `wandb_config`.
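A minimal sketch (model path and output dir are illustrative):

```bash
python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization \
    --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model --use_wandb
```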
To help with benchmarking efforts, we have added support for counting FLOPS during the fine-tuning process. You can enable this by passing `--flop_counter` when launching your single/multi-GPU fine-tuning. Use `--flop_counter_start` to choose the step at which to start counting FLOPS. It is recommended to allow a warm-up stage before using the FLOPS counter.
Similarly, you can set the `--use_profiler` flag and pass a profiling output path using `--profiler_dir` to capture the profile traces of your model using the [PyTorch profiler](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html). To get accurate profiling results, the PyTorch profiler requires a warm-up stage; the current config is wait=1, warmup=2, active=3, so the profiler starts profiling after step 3 and records the next 3 steps. Therefore, to use the PyTorch profiler, `--max_train_step` needs to be greater than 6. The PyTorch profiler can be helpful for debugging purposes. However, `--flop_counter` and `--use_profiler` cannot be used at the same time to ensure measurement accuracy.
The provided fine-tuning script allows you to select between three datasets by passing the `dataset` arg to the `llama_recipes.finetuning` module or the [`recipes/finetuning/finetuning.py`](../finetuning.py) script. The current options are `grammar_dataset`, `alpaca_dataset` and `samsum_dataset`. Additionally, we integrate the OpenAssistant/oasst1 dataset as an [example for a custom dataset](custom_dataset.py). Note: use of any of the datasets should be in compliance with the dataset's underlying licenses (including but not limited to non-commercial uses).
* [grammar_dataset](https://huggingface.co/datasets/jfleg) contains 150K pairs of English sentences and possible corrections.
* [alpaca_dataset](https://github.com/tatsu-lab/stanford_alpaca) provides 52K instruction-response pairs as generated by `text-davinci-003`.
* [samsum_dataset](https://huggingface.co/datasets/samsum) contains about 16k messenger-like conversations with summaries.
* [OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1/) contains about 88k messages from assistant-style conversations.
## Batching Strategies
Llama-recipes supports two strategies to batch requests together.
The default setting is `packing`, which concatenates the tokenized samples into long sequences filling up the context length of the model.
This is the most compute-efficient variant as it avoids any padding and all sequences have the same length.
Samples at the boundary of the context length are truncated and the remainder of the cut sequence is used as the start of the next long sequence.
If the amount of training data is small, this procedure might introduce a lot of noise into the training data, which can hurt the prediction performance of the fine-tuned model.
Therefore, we also support a `padding` strategy which does not introduce the additional noise due to truncated sequences.
The strategy tries to minimize the efficiency loss by batching samples of similar length together, so only minimal padding is necessary.
The batching strategy can be selected through the command line parameter `--batching_strategy [packing]/[padding]`, as in the example below.
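For example, to fine-tune with the `padding` strategy (other arguments are illustrative):

```bash
python -m llama_recipes.finetuning --batching_strategy padding \
    --use_peft --peft_method lora \
    --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
```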
## Using custom datasets
The list of available datasets in llama-recipes is supposed to give users a quick start on training their Llama model.
There are two ways to use a custom dataset.
The first is to provide a function that returns the dataset in a .py file, which can be given to the command line tool.
This does not involve changing the source code of llama-recipes.
The second way targets contributions that extend llama-recipes, as it involves changing the source code.
### Training on custom data
To supply a custom dataset you need to provide a single .py file which contains a function with the following signature:
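A sketch of the expected signature (see the provided datasets in `llama_recipes.datasets` for full examples):

```python
def get_custom_dataset(dataset_config, tokenizer, split: str):
    # load and tokenize the data, then return a PyTorch-style dataset for the requested split
    ...
```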
For an example of `get_custom_dataset`, you can look at the provided datasets in llama_recipes.datasets or [examples/custom_dataset.py](custom_dataset.py).
The `dataset_config` in the above signature will be an instance of llama_recipes.configs.dataset.custom_dataset with the modifications made through the command line.
The `split` argument signals whether to return the training or validation dataset.
The default function name is `get_custom_dataset` but this can be changed as described below.
In order to start a training run with the custom dataset, we need to set the `--dataset` as well as the `--custom_dataset.file` parameter, as in the example below.
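For example (training parameters omitted; the `file.py:function_name` suffix selects a non-default function name, which is how `get_foo` below gets picked up):

```bash
python -m llama_recipes.finetuning --dataset "custom_dataset" \
    --custom_dataset.file "custom_dataset.py:get_foo" \
    [TRAINING PARAMETERS]
```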
This will call the function `get_foo` instead of `get_custom_dataset` when retrieving the dataset.
### Adding new dataset
Each dataset has a corresponding configuration (dataclass) in [configs/datasets.py](../../../src/llama_recipes/configs/datasets.py) which contains the dataset name, training/validation split names, as well as optional parameters like datafiles etc.
Additionally, there is a preprocessing function for each dataset in the [datasets](../../../src/llama_recipes/datasets) folder.
The returned data of the dataset needs to be consumable by the forward method of the fine-tuned model by calling ```model(**data)```.
For CausalLM models this usually means that the data needs to be in the form of a dictionary with "input_ids", "attention_mask" and "labels" fields.
To add a custom dataset, the following steps need to be performed (a short sketch follows the list).
1. Create a dataset configuration after the schema described above. Examples can be found in [configs/datasets.py](../../../src/llama_recipes/configs/datasets.py).
2. Create a preprocessing routine which loads the data and returns a PyTorch style dataset. The signature for the preprocessing function needs to be (dataset_config, tokenizer, split_name) where split_name will be the string for train/validation split as defined in the dataclass.
3. Register the dataset name and preprocessing function by inserting it as key and value into the DATASET_PREPROC dictionary in [utils/dataset_utils.py](../../../src/llama_recipes/utils/dataset_utils.py).
4. Set the `dataset` field in the training config to the dataset name, or use the `--dataset` option of the `llama_recipes.finetuning` module or the examples/finetuning.py training script.
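A minimal sketch of steps 1 and 3 for a hypothetical `my_new_dataset` (field names follow the existing configs; the preprocessing function `get_my_new_dataset` is assumed to exist per step 2):

```python
# configs/datasets.py (step 1): dataset configuration following the existing schema
from dataclasses import dataclass

@dataclass
class my_new_dataset:
    dataset: str = "my_new_dataset"
    train_split: str = "train"
    test_split: str = "validation"
    data_path: str = "PATH/to/data.json"

# utils/dataset_utils.py (step 3): register the preprocessing function
# DATASET_PREPROC["my_new_dataset"] = get_my_new_dataset
```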
## Application
Below we list other datasets and their main use cases that can be used for fine tuning.
This recipe steps you through how to finetune a Meta Llama 3 model on the text summarization task using the [samsum](https://huggingface.co/datasets/samsum) dataset on multiple GPUs, on a single node or across multiple nodes.
## Requirements
Ensure that you have installed the llama-recipes package ([details](../../README.md#installing)).
We will also need 2 packages:
1. [PEFT](https://github.com/huggingface/peft) to use parameter-efficient finetuning.
2. [FSDP](https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html) which helps us parallelize the training over multiple GPUs. [More details](./LLM_finetuning_overview.md#2-full-partial-parameter-finetuning).
> [!NOTE]
> The llama-recipes package will install PyTorch version 2.0.1. In case you want to use FSDP with PEFT for multi-GPU finetuning, please install the PyTorch nightlies ([details](../../README.md#pytorch-nightlies)).
>
> INT8 quantization is not currently supported in FSDP
## How to run it
Get access to a machine with multiple GPUs (in this case we tested with 4 A100s and A10s).
Here we use a Slurm script to schedule a job over multiple nodes.
```bash
# Change the num nodes and GPU per nodes in the script before running.
sbatch ./multi_node.slurm
```
We use `torchrun` to spawn multiple processes for FSDP.
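For reference, a representative single-node launch combining FSDP and PEFT might look like the following (model path, GPU count and output dir are illustrative):

```bash
torchrun --nnodes 1 --nproc_per_node 4 recipes/finetuning/finetuning.py \
    --enable_fsdp --use_peft --peft_method lora \
    --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
```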
The args used in the command above are:
* `--enable_fsdp` boolean flag to enable FSDP in the script.
* `--use_peft` boolean flag to enable PEFT methods in the script.
* `--peft_method` to specify the PEFT method; here we use `lora`, other options are `llama_adapter` and `prefix`.
### With only FSDP
If you are interested in running full parameter fine-tuning without making use of PEFT methods, please use the following command. Make sure to change `nproc_per_node` to the number of your available GPUs. This has been tested with `BF16` on 8x A100, 40GB GPUs.
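A sketch of such a run (paths illustrative):

```bash
torchrun --nnodes 1 --nproc_per_node 8 recipes/finetuning/finetuning.py \
    --enable_fsdp --model_name /path_of_model_folder/8B \
    --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
```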
If you are running full parameter fine-tuning on the 70B model, you can enable the `low_cpu_fsdp` mode with the following command. This option loads the model on rank 0 only before moving it to the devices to construct FSDP. This can dramatically save CPU memory when loading large models like the 70B (on an 8-GPU node, this reduces CPU memory from 2+ TB to 280 GB for the 70B model). This has been tested with `BF16` on 16x A100, 80GB GPUs.
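A sketch of such a 70B run (paths and batch size are illustrative):

```bash
torchrun --nnodes 1 --nproc_per_node 8 recipes/finetuning/finetuning.py \
    --enable_fsdp --low_cpu_fsdp --fsdp_config.pure_bf16 \
    --model_name /path_of_model_folder/70B --batch_size_training 1 \
    --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
```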
Currently 3 open source datasets are supported and can be found in the [Datasets config file](../../src/llama_recipes/configs/datasets.py). You can also use your custom dataset (more info [here](./datasets/README.md)). Pass the dataset name via `--dataset`, as in the example after this list.
* `grammar_dataset` : use this [notebook](../../src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) to pull and process the Jfleg and C4 200M datasets for grammar checking.
* `alpaca_dataset` : to get this open source data please download `alpaca.json` to the `dataset` folder.
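For example, to switch to `grammar_dataset` on a multi-GPU run (other arguments are illustrative):

```bash
torchrun --nnodes 1 --nproc_per_node 4 recipes/finetuning/finetuning.py \
    --enable_fsdp --use_peft --peft_method lora \
    --dataset grammar_dataset \
    --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
```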
In case you are dealing with a slower interconnect network between nodes, you can make use of the `--hsdp` flag to reduce the communication overhead.
HSDP (Hybrid Sharding Data Parallel) defines a hybrid sharding strategy where you have FSDP within `sharding_group_size`, which can be the minimum number of GPUs your model fits into, and DDP between the replicas of the model specified by `replica_group_size`.
This requires setting the sharding strategy in the [fsdp config](../../src/llama_recipes/configs/fsdp.py) to `ShardingStrategy.HYBRID_SHARD` and specifying two additional settings, `sharding_group_size` and `replica_group_size`: the former specifies the sharding group size, i.e. the number of GPUs your model fits into to form one replica of the model, and the latter specifies the replica group size, which is world_size/sharding_group_size.
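A sketch of a multi-node HSDP launch (node counts and group sizes are illustrative; `replica_group_size` should equal world_size/sharding_group_size):

```bash
torchrun --nnodes 4 --nproc_per_node 8 recipes/finetuning/finetuning.py \
    --enable_fsdp --hsdp --sharding_group_size 8 --replica_group_size 4 \
    --model_name /path_of_model_folder/70B \
    --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
```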
To help with benchmarking efforts, we have added support for counting FLOPS during the fine-tuning process. You can enable this by passing `--flop_counter` when launching your single/multi-GPU fine-tuning. Use `--flop_counter_start` to choose the step at which to start counting FLOPS. It is recommended to allow a warm-up stage before using the FLOPS counter.
Similarly, you can set the `--use_profiler` flag and pass a profiling output path using `--profiler_dir` to capture the profile traces of your model using the [PyTorch profiler](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html). To get accurate profiling results, the PyTorch profiler requires a warm-up stage; the current config is wait=1, warmup=2, active=3, so the profiler starts profiling after step 3 and records the next 3 steps. Therefore, to use the PyTorch profiler, `--max_train_step` needs to be greater than 6. The PyTorch profiler can be helpful for debugging purposes. However, `--flop_counter` and `--use_profiler` cannot be used at the same time to ensure measurement accuracy.
This recipe steps you through how to finetune a Meta Llama 3 model on the text summarization task using the [samsum](https://huggingface.co/datasets/samsum) dataset on a single GPU.
These are the instructions for using the canonical [finetuning script](../../src/llama_recipes/finetuning.py) in the llama-recipes package.
## Requirements
Ensure that you have installed the llama-recipes package ([details](../../README.md#installing)).
To run fine-tuning on a single GPU, we will make use of two packages:
1. [PEFT](https://github.com/huggingface/peft) to use parameter-efficient finetuning.
2. [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) for int8 quantization.
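With those installed, a representative invocation looks like the following (model path and output dir are illustrative); the flags are described below:

```bash
python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization \
    --model_name /path_of_model_folder/8B --output_dir Path/to/save/PEFT/model
```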
* `--use_peft` boolean flag to enable PEFT methods in the script.
* `--peft_method` to specify the PEFT method; here we use `lora`, other options are `llama_adapter` and `prefix`.
* `--quantization` boolean flag to enable int8 quantization.
> [!NOTE]
> In case you are using a multi-GPU machine please make sure to only make one of them visible using `export CUDA_VISIBLE_DEVICES=GPU:id`.
### How to run with different datasets?
Currently 3 open source datasets are supported that can be found in [Datasets config file](../../src/llama_recipes/configs/datasets.py). You can also use your custom dataset (more info [here](./datasets/README.md)).
* `grammar_dataset` : use this [notebook](../../src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) to pull and process the Jfleg and C4 200M datasets for grammar checking.
* `alpaca_dataset` : to get this open source data please download `alpaca.json` to the `dataset` folder.
To help with benchmarking efforts, we have added support for counting FLOPS during the fine-tuning process. You can enable this by passing `--flop_counter` when launching your single/multi-GPU fine-tuning. Use `--flop_counter_start` to choose the step at which to start counting FLOPS. It is recommended to allow a warm-up stage before using the FLOPS counter.
Similarly, you can set the `--use_profiler` flag and pass a profiling output path using `--profiler_dir` to capture the profile traces of your model using the [PyTorch profiler](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html). To get accurate profiling results, the PyTorch profiler requires a warm-up stage; the current config is wait=1, warmup=2, active=3, so the profiler starts profiling after step 3 and records the next 3 steps. Therefore, to use the PyTorch profiler, `--max_train_step` needs to be greater than 6. The PyTorch profiler can be helpful for debugging purposes. However, `--flop_counter` and `--use_profiler` cannot be used at the same time to ensure measurement accuracy.
For local inference we have provided an [inference script](inference.py). Depending on the type of finetuning performed during training the [inference script](inference.py) takes different arguments.
If all model parameters were fine-tuned, the output dir of the training has to be given as the --model_name argument.
In the case of a parameter-efficient method like LoRA, the base model has to be given as --model_name and the output dir of the training has to be given as the --peft_model argument.
Additionally, a prompt for the model in the form of a text file has to be provided. The prompt file can either be piped through standard input or given as the --prompt_file parameter, as in the sketch below.
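A sketch of both cases (prompt file and paths are illustrative):

```bash
# full-parameter fine-tuning: pass the training output dir as the model
cat samsum_prompt.txt | python inference.py --model_name Path/of/training/output_dir

# PEFT (e.g. LoRA): pass the base model and the PEFT output dir separately
cat samsum_prompt.txt | python inference.py --model_name /path_of_base_model/8B \
    --peft_model Path/to/save/PEFT/model
```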
**Content Safety**
The inference script also supports safety checks for both user prompt and model outputs. In particular, we use two packages, [AuditNLG](https://github.com/salesforce/AuditNLG/tree/main) and [Azure content safety](https://pypi.org/project/azure-ai-contentsafety/1.0.0b1/).
**Note**
If using Azure Content Safety, please make sure to get the endpoint and API key as described [here](https://pypi.org/project/azure-ai-contentsafety/1.0.0b1/) and add them as the following environment variables: `CONTENT_SAFETY_ENDPOINT` and `CONTENT_SAFETY_KEY`.
The folder contains test prompts for the summarization use case:
```
samsum_prompt.txt
...
```
**Note**
Currently, the pad token by default in the [HuggingFace Tokenizer is `None`](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py#L110). We add the padding token as a special token to the tokenizer, which in this case requires resizing the token embeddings as shown below:
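A sketch of the relevant calls (the pad-token string is illustrative):

```python
# add a dedicated padding token and resize the embedding matrix accordingly
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
model.resize_token_embeddings(len(tokenizer))
```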
Padding would be required for batch inference. In this [example](inference.py), batch size = 1, so essentially padding is not required. However, we added the code pointer as an example in case of batch inference.
## Chat completion
The inference folder also includes a chat completion example that adds built-in safety features in fine-tuned models to the prompt tokens. To run the example:
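A sketch of the invocation (the script path and prompt file name are assumptions based on the repo layout; adjust to your checkout):

```bash
python chat_completion.py --model_name "PATH/TO/MODEL/8B/" \
    --prompt_file chats.json --quantization
```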
## Flash Attention and Xformer Memory Efficient Kernels
Setting `use_fast_kernels` will enable the use of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This can speed up inference when used for batched inputs. This has been enabled in HuggingFace's `optimum` library as a one-liner API; please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
In case you have fine-tuned your model with pure FSDP and saved the checkpoints with "SHARDED_STATE_DICT" as shown [here](../../../src/llama_recipes/configs/fsdp.py), you can use this converter script to convert the FSDP Sharded checkpoints into HuggingFace checkpoints. This enables you to use the inference script normally as mentioned above.
**To convert the checkpoint, use the following command**:
This is helpful if you have fine-tuned your model using FSDP only (without PEFT), as described in the full-parameter fine-tuning section above.
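An illustrative invocation (the module path is an assumption based on the llama-recipes package layout; adjust paths to your setup):

```bash
python -m llama_recipes.inference.checkpoint_converter_fsdp_hf \
    --fsdp_checkpoint_path PATH/to/FSDP/model/checkpoints \
    --consolidated_model_path PATH/to/save/checkpoints \
    --HF_model_path_or_name PATH/or/HF/model_name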
# --HF_model_path_or_name specifies the HF Llama model name or path where it has config.json and tokenizer.json
```
By default, training parameters are saved in `train_params.yaml` in the path where FSDP checkpoints are saved. In the converter script we first try to find the HuggingFace model name used in the fine-tuning to load the model with its configs from there; if it is not found, the user needs to provide it.
For reference, the chat completion and inference scripts expose generation and safety settings such as the following:

```python
max_new_tokens = 256,  # The maximum number of tokens to generate
min_new_tokens: int = 0,  # The minimum number of tokens to generate
prompt_file: str = None,
seed: int = 42,  # seed value for reproducibility
safety_score_threshold: float = 0.5,
do_sample: bool = True,  # Whether or not to use sampling; use greedy decoding otherwise.
use_cache: bool = True,  # [optional] Whether or not the model should use the past key/values attentions (if applicable to the model) to speed up decoding.
top_p: float = 1.0,  # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature: float = 1.0,  # [optional] The value used to modulate the next token probabilities.
top_k: int = 50,  # [optional] The number of highest probability vocabulary tokens to keep for top-k filtering.
repetition_penalty: float = 1.0,  # The parameter for repetition penalty. 1.0 means no penalty.
length_penalty: int = 1,  # [optional] Exponential penalty to the length that is used with beam-based generation.
enable_azure_content_safety: bool = False,  # Enable safety check with Azure content safety API
enable_sensitive_topics: bool = False,  # Enable check for sensitive topics using AuditNLG APIs
```
The following example dialogs illustrate the expected chat format:

```json
[
    [{"role": "user", "content": "what is the recipe of mayonnaise?"}],
    [
        {"role": "user", "content": "I am going to Paris, what should I see?"},
        {
            "role": "assistant",
            "content": "Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris: 1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city. 2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa. 3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows. These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world."
        },
        {"role": "user", "content": "What is so great about #1?"}
    ],
    [
        {"role": "system", "content": "Always answer with Haiku"},
        {"role": "user", "content": "I am going to Paris, what should I see?"}
    ],
    [
        {
            "role": "system",
            "content": "Always answer with emojis"
        },
        {"role": "user", "content": "How to go from Beijing to NY?"}
    ],
    [
        {
            "role": "system",
            "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
        },
        {"role": "user", "content": "Write a brief birthday message to John"}
    ]
]
```
A similar set of generation settings (with slightly different defaults) is exposed as well:

```python
max_new_tokens = 100,  # The maximum number of tokens to generate
prompt_file: str = None,
seed: int = 42,  # seed value for reproducibility
do_sample: bool = True,  # Whether or not to use sampling; use greedy decoding otherwise.
min_length: int = None,  # The minimum length of the sequence to be generated, input prompt + min_new_tokens
use_cache: bool = True,  # [optional] Whether or not the model should use the past key/values attentions (if applicable to the model) to speed up decoding.
top_p: float = 1.0,  # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature: float = 1.0,  # [optional] The value used to modulate the next token probabilities.
top_k: int = 50,  # [optional] The number of highest probability vocabulary tokens to keep for top-k filtering.
repetition_penalty: float = 1.0,  # The parameter for repetition penalty. 1.0 means no penalty.
length_penalty: int = 1,  # [optional] Exponential penalty to the length that is used with beam-based generation.
enable_azure_content_safety: bool = False,  # Enable safety check with Azure content safety API
enable_sensitive_topics: bool = False,  # Enable check for sensitive topics using AuditNLG APIs
enable_salesforce_content_safety: bool = True,  # Enable safety check with Salesforce safety flan t5
enable_llamaguard_content_safety: bool = False,
max_padding_length: int = None,  # the max padding length to be used with tokenizer padding the prompts.
use_fast_kernels: bool = False,  # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformer memory-efficient kernels
```
# Running Llama3 8B Instruct on Android with MLC-LLM
Author: Thierry Moreau - tmoreau@octo.ai
# Overview
In this tutorial we'll learn how to deploy Llama3 8B Instruct on an Android-based phone using MLC-LLM.
Machine Learning Compilation for Large Language Models (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language models with native APIs with compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques.
You can read more about MLC-LLM at the following [link](https://github.com/mlc-ai/mlc-llm).
MLC-LLM is also what powers the Llama3 inference APIs provided by [OctoAI](https://octo.ai/). You can use OctoAI for your Llama3 cloud-based inference needs by trying out the examples under the [following path](../../../llama_api_providers/OctoAI_API_examples/).
This tutorial was tested with the following setup:
* MacBook Pro 16 inch from 2021 with Apple M1 Max and 32GB of RAM running Sonoma 14.3.1
* OnePlus 12 Android Smartphone with a Snapdragon 8Gen3 SoC and 12GB of RAM, running OxygenOS 14.0
Running Llama3 on a phone will likely require a powerful chipset. We haven't extensively tested the range of chipsets that will support this use case. Feel free to update this README.md to specify which devices were successfully tested.
This guide is heavily based on the [MLC Android Guide](https://llm.mlc.ai/docs/deploy/android.html), but several steps have been taken to streamline the instructions.
# Pre-requisites
## Python
Whether you're using conda or virtual env to manage your environment, we highly recommend starting from scratch with a clean new environment.
For instance with virtual environment:
```bash
python3 -m venv .venv
source .venv/bin/activate
```
Next you'll need to install the following packages:
```bash
python3 -m pip install -r requirements.txt
```
## Rust
[Rust](https://www.rust-lang.org/tools/install) is needed to cross-compile HuggingFace tokenizers to Android.
Make sure rustc, cargo, and rustup are available in $PATH.
## Android Studio
Install Android Studio from <!-- markdown-link-check-disable -->https://developer.android.com/studio<!-- markdown-link-check-enable --> with NDK and CMake.
To install NDK and CMake, in the Android Studio welcome page, click “Projects → SDK Manager → SDK Tools”. Set up the following environment variables:
* ANDROID_NDK so that $ANDROID_NDK/build/cmake/android.toolchain.cmake is available.
* TVM_NDK_CC that points to NDK's clang compiler.
For instance, the paths will look like the following on OSX for user `moreau`:
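(The exact NDK version and clang target are installation-specific; the lines below are a sketch with placeholders.)

```bash
export ANDROID_NDK="/Users/moreau/Library/Android/sdk/ndk/<ndk-version>"
export TVM_NDK_CC="$ANDROID_NDK/toolchains/llvm/prebuilt/darwin-x86_64/bin/aarch64-linux-android<api-level>-clang"
```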
This tutorial was tested successfully on Android Studio Hedgehog | 2023.1.1 Patch 1.
## JDK
A JDK, such as OpenJDK >= 17, is needed to compile the Java bindings of the TVM Unity runtime.
We strongly recommend setting the JAVA_HOME to the JDK bundled with Android Studio. Using Android Studio’s JBR bundle as recommended (<!-- markdown-link-check-disable -->https://developer.android.com/build/jdks<!-- markdown-link-check-enable -->) will reduce the chances of potential errors in JNI compilation.
For instance on macOS, you'll need to point JAVA_HOME to the following.
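(Assuming the default Android Studio install location; adjust if yours differs.)

```bash
export JAVA_HOME="/Applications/Android Studio.app/Contents/jbr/Contents/Home"
```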
At the time of writing this README, we tested `mlc-llm` at the following sha: `21feb7010db02e0c2149489f5972d6a8a796b5a0`.
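If you haven't checked out `mlc-llm` yet, a recursive clone pinned to that sha, with `MLC_LLM_HOME` pointing at the checkout (as used later in this guide), is one way to reproduce the setup:

```bash
git clone --recursive https://github.com/mlc-ai/mlc-llm.git
cd mlc-llm
git checkout 21feb7010db02e0c2149489f5972d6a8a796b5a0
git submodule update --init --recursive   # align submodules with the pinned sha
export MLC_LLM_HOME=$(pwd)
```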
## Phone Setup
Enable debugging on your phone in your phone’s developer settings. Each phone manufacturer has its own approach to enabling debug mode, so a simple Google search should equip you with the steps to do that on your phone.
In addition, make sure to change your USB configuration from "Charging" to "MTP (Media Transfer Protocol)". This will allow us to connect to the device serially.
Connect your phone to your development machine. On OSX, you'll be prompted on the dev machine whether you want to allow the accessory to connect. Hit "Allow".
# Build Steps
## Building the Android Package with MLC
First edit the file under `android/MLCChat/mlc-package-config.json`, replacing its contents with the [mlc-package-config.json](./mlc-package-config.json) provided in llama-recipes.
To understand what these JSON fields mean you can refer to this [documentation](https://llm.mlc.ai/docs/deploy/android.html#step-2-build-runtime-and-model-libraries).
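The packaging itself is run from the MLCChat directory. The command below follows the MLC-LLM Android packaging flow and assumes the `mlc_llm` CLI is installed in your environment:

```bash
cd $MLC_LLM_HOME/android/MLCChat
export MLC_LLM_SOURCE_DIR=$MLC_LLM_HOME   # assumed: points the packager at the mlc-llm checkout
mlc_llm package
```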
The command above will take a few minutes to run as it runs through the following steps:
* Compile the Llama 3 8B Instruct model specified in the `mlc-package-config.json` into a binary model library.
* Build the `mlc-llm` runtime and tokenizer. In addition to the model itself, a lightweight runtime and tokenizer are required to actually run the LLM.
## Building and Running MLC Chat in Android Studio
Now let's launch Android Studio.
* On the "Welcome to Android Studio" page, hit "Open", and navigate to `$MLC_LLM_HOME/android/MLCChat`, then hit "Open"
* A window will pop up asking whether to "Trust and Open project 'MLCChat'" - hit "Trust Project"
* The project will now launch
* Under File -> Project Structure... -> Project change the Gradle Version (second drop down from the top) to 8.5
Connect your phone to your development machine - assuming you've followed the setup steps in the pre-requisite section, you should be able to see the device.
Next you'll need to:
* Hit Build -> Make Project.
* Hit Run -> Run 'app'
The MLCChat app will launch on your phone. Now, on your phone:
* Under Model List you'll see the `Llama-3-8B-Instruct` LLM listed.
* The model is not quite ready to launch yet, because the weights need to be downloaded over Wi-Fi first. Hit the Download button to the right of the model name to download the weights from HuggingFace.
Note that you can change the build settings to bundle the weights with the MLCChat app so you don't have to download the weights over wifi. To do so you can follow the instructions [here](https://llm.mlc.ai/docs/deploy/android.html#bundle-model-weights).
Once the model weights are downloaded you can now interact with Llama 3 locally on your Android phone!
## [Running Llama 3 On-Prem with vLLM and TGI](llama-on-prem.md)
This tutorial shows how to use Llama 3 with [vLLM](https://github.com/vllm-project/vllm) and Hugging Face [TGI](https://github.com/huggingface/text-generation-inference) to build Llama 3 on-prem apps.