---
title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---

## Overview

Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
feedback. Methods include, but are not limited to:

- [Direct Preference Optimization (DPO)](#dpo)
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- Proximal Policy Optimization (PPO) (not yet supported in axolotl; if you're interested in contributing, please reach out!)

## RLHF using Axolotl

::: {.callout-important}
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
:::

We rely on the [TRL](https://github.com/huggingface/trl) library for implementations of various RL training methods, which we wrap to expose in axolotl. Each method has its own supported ways of loading datasets and prompt formats.

::: {.callout-tip}
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}` where `{method}` is one of our supported methods. The `type:` value can be retrieved from `{method}.{function_name}`.
:::

### DPO

Example config:

```yaml
rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: chatml
```

DPO supports the following types with the following dataset format:

#### chatml.argilla

```json
{
    "system": "...", // optional
    "instruction": "...",
    "chosen_response": "...",
    "rejected_response": "..."
}
```

#### chatml.argilla_chat

```json
{
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}
```

#### chatml.icr

```json
{
    "system": "...", // optional
    "input": "...",
    "chosen": "...",
    "rejected": "..."
}
```

#### chatml.intel

```json
{
    "system": "...", // optional
    "question": "...",
    "chosen": "...",
    "rejected": "..."
}
```

#### chatml.prompt_pairs

```json
{
    "system": "...", // optional
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}
```

#### chatml.ultra

```json
{
    "system": "...", // optional
    "prompt": "...",
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}
```

#### llama3.argilla

```json
{
    "system": "...", // optional
    "instruction": "...",
    "chosen_response": "...",
    "rejected_response": "..."
}
```

#### llama3.argilla_chat

```json
{
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}
```

#### llama3.icr

```json
{
    "system": "...", // optional
    "input": "...",
    "chosen": "...",
    "rejected": "..."
}
```

#### llama3.intel

```json
{
    "system": "...", // optional
    "question": "...",
    "chosen": "...",
    "rejected": "..."
}
```

#### llama3.prompt_pairs

```json
{
    "system": "...", // optional
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}
```

#### llama3.ultra

```json
{
    "system": "...", // optional
    "prompt": "...",
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}
```

#### zephyr.nectar

```json
{
    "prompt": "...",
    "answers": [
        {
            "answer": "...",
            "rank": 1
        },
        {
            "answer": "...",
            "rank": 2
        }
        // ...
        // more answers with ranks
    ]
}
```

#### chat_template.default

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: chat_template.default
    field_messages: "messages"
    field_chosen: "chosen"
    field_rejected: "rejected"
    message_property_mappings:
      role: role
      content: content
    roles:
      user: ["user"]
      assistant: ["assistant"]
      system: ["system"]
```

Sample input format:

```json
{
    "messages": [
        {
            "role": "system",
            "content": "..."
        },
        {
            "role": "user",
            "content": "..."
        },
        // ... more messages
    ],
    "chosen": {
        "role": "assistant",
        "content": "..."
    },
    "rejected": {
        "role": "assistant",
        "content": "..."
    }
}
```

#### user_defined.default

For custom behaviors, use the following config:

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: user_defined.default
    field_prompt: "prompt"
    field_system: "system"
    field_chosen: "chosen"
    field_rejected: "rejected"
    prompt_format: "{prompt}"
    chosen_format: "{chosen}"
    rejected_format: "{rejected}"
```

Each input record is a simple JSON object whose field names are customizable through the config above.

```json
{
    "system": "...", // optional
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}
```

### IPO

As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
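To make "a different loss function" concrete, here is a minimal sketch of the two per-pair losses as they are commonly written: DPO applies a negative log-sigmoid to the scaled preference margin, while IPO regresses the margin toward the target `1 / (2 * beta)`. The function names are hypothetical, the log-probabilities are plain numbers you supply, and details such as how log-probabilities are aggregated over tokens are left out. This is not TRL's or axolotl's implementation.

```python
import math


def preference_margin(policy_chosen, policy_rejected, ref_chosen, ref_rejected):
    """Policy log-ratio minus reference log-ratio for one preference pair."""
    return (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)


def dpo_loss(margin, beta=0.1):
    # Sigmoid DPO loss: -log(sigmoid(beta * margin)) == log(1 + exp(-beta * margin))
    return math.log1p(math.exp(-beta * margin))


def ipo_loss(margin, beta=0.1):
    # IPO loss: squared distance between the margin and the target 1 / (2 * beta)
    return (margin - 1.0 / (2.0 * beta)) ** 2


if __name__ == "__main__":
    # Illustrative (made-up) sequence log-probabilities for one pair.
    margin = preference_margin(-1.0, -3.0, -2.0, -2.5)
    print("margin:", margin)
    print("DPO loss:", dpo_loss(margin))
    print("IPO loss:", ipo_loss(margin))
```

The DPO loss keeps shrinking as the margin grows, whereas the IPO loss is smallest at the target margin and grows again beyond it. In axolotl none of this is written by hand; selecting the method is a config change: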
```yaml
rl: ipo
```

### ORPO

Paper: https://arxiv.org/abs/2403.07691

```yaml
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false

chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla
```

ORPO supports the following types with the following dataset format:

#### chat_template.argilla

```json
{
    "system": "...", // optional
    "prompt": "...", // if available, will be taken as the user message for single-turn instead of from the lists below

    // chosen/rejected should be the same up to the last message, with an even number of alternating user/assistant turns
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}
```

### KTO

```yaml
rl: kto
rl_beta: 0.1  # default
kto_desirable_weight: 1.0  # default
kto_undesirable_weight: 1.0  # default

remove_unused_columns: false

datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    type: llama3.ultra
    split: train

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
```

KTO supports the following types with the following dataset format:

#### chatml.argilla

```json
{
    "system": "...", // optional
    "instruction": "...",
    "completion": "..."
}
```

#### chatml.argilla_chat

```json
{
    "chosen": [
        {"role": "user", "content": "..."}
    ],
    "completion": [
        {"role": "assistant", "content": "..."}
    ]
}
```

#### chatml.intel

```json
{
    "system": "...", // optional
    "question": "...",
    "completion": "..."
}
```

#### chatml.prompt_pairs

```json
{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}
```

#### chatml.ultra

```json
{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}
```

#### llama3.argilla

```json
{
    "system": "...", // optional
    "instruction": "...",
    "completion": "..."
}
```

#### llama3.argilla_chat

```json
{
    "completion": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}
```

#### llama3.intel

```json
{
    "system": "...", // optional
    "question": "...",
    "completion": "..."
}
```

#### llama3.prompt_pairs

```json
{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}
```

#### llama3.ultra

```json
{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}
```

#### user_defined.default

For custom behaviors, use the following config:

```yaml
rl: kto
datasets:
  - path: ...
    split: train
    type: user_defined.default
    field_prompt: "prompt"
    field_system: "system"
    field_completion: "completion"
    field_label: "label"
    prompt_format: "{prompt}"
    completion_format: "{completion}"
```

Each input record is a simple JSON object whose field names are customizable through the config above.

```json
{
    "system": "...", // optional
    "prompt": "...",
    "completion": "...",
    "label": "..."
}
```

### GRPO

::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::

In the latest GRPO implementation, `vLLM` is used to significantly speed up trajectory generation during training. In this example, we use 4 GPUs, 2 for training and 2 for vLLM:

::: {.callout-important}
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
:::

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85
  dtype: auto
  # max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand

rl: grpo
trl:
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_server_timeout: 300
```

```bash
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo.yaml
```

Your `vLLM` instance will now attempt to spin up. Meanwhile, kick off training on the remaining two GPUs by running the following in another terminal:

```bash
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
```

::: {.callout-note}
Due to TRL's implementation with vLLM, the vLLM instance must use the last N GPUs instead of the first N GPUs. This is why, in the example above, we use `CUDA_VISIBLE_DEVICES=2,3` for the vLLM instance.
:::

#### Reward functions

GRPO uses custom reward functions and transformations. Please have them ready locally.

For example, to load OpenAI's GSM8K and use a random reward for completions:

```python
# rewards.py
import random

def rand_reward_func(completions, **kwargs) -> list[float]:
    # Placeholder reward: one random score in [0, 1] per completion
    return [random.uniform(0, 1) for _ in completions]

def oai_gsm8k_transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        # GSM8K answers end with "#### <final answer>"; keep only the final answer
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [{"role": "user", "content": example["question"]},],
            "answer": label,
        }
    # Return the per-example transform plus kwargs for the dataset mapping step
    return transform_fn, {"remove_columns": ["question"]}
```

```yaml
rl: grpo

trl:
  beta: 0.001
  max_completion_length: 256
  use_vllm: True
  num_generations: 4
  reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
  reward_weights: [1.0]
datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform  # format: '{file_name}.{fn_name}'
```

To see other examples of custom reward functions, please see [TRL GRPO
Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).

To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).

#### GRPO with DAPO/Dr. GRPO loss

The DAPO paper, and subsequently the Dr. GRPO paper, proposed an alternative loss function for GRPO to remediate the penalty in longer responses.

```yaml
trl:
  loss_type: dr_grpo
  # Normalizes loss based on max completion length (default: 256)
  max_completion_length:
```

For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).

### SimPO

SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with an alternative loss function.

```yaml
rl: simpo
rl_beta: 0.1  # default in CPOTrainer
cpo_alpha: 1.0  # default in CPOTrainer
simpo_gamma: 0.5  # default in CPOTrainer
```

This method uses the same dataset format as [DPO](#dpo).

### Using local dataset files

```yaml
datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel
```

### TRL auto-unwrapping for PEFT

TRL supports auto-unwrapping PEFT models for RL training paradigms which rely on a reference model. This significantly reduces memory pressure, as an additional reference model does not need to be loaded and reference model log-probabilities can be obtained by disabling PEFT adapters. This is enabled by default. To turn it off, pass the following config:

```yaml
# load ref model when adapter training.
rl_adapter_ref_model: true
```
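
The idea of obtaining reference outputs by disabling adapters can be shown with a toy model. The sketch below is purely illustrative: `ToyAdapterModel`, its `score` method, and the numeric weights are hypothetical, and the `disable_adapter` method only imitates the kind of context manager an adapter library might expose. It is not TRL's, PEFT's, or axolotl's code.

```python
from contextlib import contextmanager


class ToyAdapterModel:
    """Hypothetical stand-in for an adapter-wrapped model:
    a frozen base weight plus a trainable adapter delta."""

    def __init__(self, base_weight: float, adapter_delta: float):
        self.base_weight = base_weight
        self.adapter_delta = adapter_delta
        self.adapter_enabled = True

    @contextmanager
    def disable_adapter(self):
        # Temporarily switch the adapter off, then restore the previous state.
        previous = self.adapter_enabled
        self.adapter_enabled = False
        try:
            yield self
        finally:
            self.adapter_enabled = previous

    def score(self, x: float) -> float:
        # With the adapter on this is the "policy"; with it off, the frozen base.
        weight = self.base_weight
        if self.adapter_enabled:
            weight += self.adapter_delta
        return weight * x


model = ToyAdapterModel(base_weight=2.0, adapter_delta=0.5)

policy_score = model.score(3.0)         # adapter active: policy output
with model.disable_adapter():
    reference_score = model.score(3.0)  # adapter off: base model acts as reference

print("policy:", policy_score, "reference:", reference_score)
```

Because a single set of base weights serves both roles, no second copy of the model is held in memory, which is the saving described above.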