# GRPO guide: https://docs.axolotl.ai/docs/rlhf.html#grpo
# Adapted from the cookbook example: https://github.com/axolotl-ai-cloud/axolotl-cookbook/blob/main/grpo/gsm8k.yaml
# --- Base model -------------------------------------------------------------
base_model: Qwen/Qwen3-4B
# To push checkpoints and the final model to the Hugging Face Hub, uncomment:
# hub_model_id: username/custom_model_name

# Load full-precision weights (no 4-bit / 8-bit quantized loading).
load_in_4bit: false
load_in_8bit: false

strict: false
torch_compile: false
# --- vLLM generation server -------------------------------------------------
# These keys must be nested under `vllm:`. Left unindented, `vllm` parses as
# null and the keys end up as unrelated top-level settings.
# Keep `port` in sync with `trl.vllm_server_port`.
vllm:
  host: 0.0.0.0
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85
  dtype: auto
  # max_model_len:  # optional; if set, it must cover prompt length + trl.max_completion_length
# --- GRPO / TRL -------------------------------------------------------------
rl: grpo
trl:
  # loss_type: dr_grpo  # alternative: dapo
  beta: 0.001
  max_completion_length: 512
  num_generations: 2

  # Completions are generated by the external vLLM server configured under
  # `vllm:`; host/port here must point at that server.
  use_vllm: true
  vllm_server_host: localhost
  vllm_server_port: 8000
  vllm_server_timeout: 300

  # Rewards come from the cookbook's gsm8k_grpo module, which must be
  # importable (see the PYTHONPATH hint next to `datasets`).
  reward_funcs:
    - gsm8k_grpo.correctness_reward_func
    - gsm8k_grpo.int_reward_func
    - gsm8k_grpo.strict_format_reward_func
    - gsm8k_grpo.soft_format_reward_func
    - gsm8k_grpo.xmlcount_reward_func

  # NOTE(review): the two keys below tune an in-process (colocated) vLLM
  # engine and contradicted this setup. A 256-token model length cannot hold
  # a 512-token completion, and 0.9 disagrees with vllm.gpu_memory_utilization
  # (0.85). In server mode the `vllm:` section presumably governs the engine,
  # so they are disabled here. Confirm against the installed axolotl/TRL
  # version before re-enabling.
  # vllm_gpu_memory_utilization: 0.9
  # vllm_max_model_len: 256
# --- Data -------------------------------------------------------------------
chat_template: qwen3
datasets:
  # The transform lives in the cookbook's gsm8k_grpo module; make it importable:
  #   export PYTHONPATH=/home/axolotl/axolotl-cookbook/grpo/:$PYTHONPATH
  # `type` must be indented into this list item, or it detaches from the dataset.
  - path: skrishna/gsm8k_only_answer
    type: gsm8k_grpo.axo_gsm8k_transform
dataset_prepared_path: null
skip_prepare_dataset: true
val_set_size: 0.0
output_dir: ./outputs/out

# Sequence handling: no packing and no padding up to sequence_len.
sequence_len: 256
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: false

# DataLoader worker settings.
dataloader_num_workers: 2
dataloader_prefetch_factor: 32
dataloader_pin_memory: true

gc_steps: 1
# Weights & Biases: explicitly unset. Set wandb_project (and optionally
# wandb_entity / wandb_name) to log this run.
wandb_project: null
wandb_entity: null
wandb_name: null
# --- Optimization -----------------------------------------------------------
gradient_accumulation_steps: 8
# TRL's GRPO trainer requires the global batch (micro_batch_size x number of
# training processes) to be evenly divisible by trl.num_generations; verify
# against the installed TRL version.
micro_batch_size: 2
num_epochs: 1

optimizer: adamw_torch_fused
lr_scheduler: constant_with_warmup
learning_rate: 1.0e-6  # keep the decimal point: YAML 1.1 loaders read `1e-6` as a string
warmup_steps: 100
max_grad_norm: 1.0
weight_decay: 0.1

bf16: auto
tf32: false
flash_attention: true

# `use_reentrant` must be nested under the kwargs mapping; unindented, the
# kwargs become null and `use_reentrant` turns into a stray top-level key.
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

# --- Logging / checkpointing ------------------------------------------------
logging_steps: 1
# NOTE(review): val_set_size is 0.0, so there may be no eval split for these
# evaluations to run on — confirm this is intended.
evals_per_epoch: 1
saves_per_epoch: 4