# rl_qwen3-4b.yaml
# Revisions: v1.0, v1.0.3 (author: chenzk)
# http://docs.axolotl.ai/docs/rlhf.html#grpo
# https://github.com/axolotl-ai-cloud/axolotl-cookbook/blob/main/grpo/gsm8k.yaml

base_model: Qwen/Qwen3-4B
# To upload checkpoints and the final model to the Hugging Face Hub, uncomment:
# hub_model_id: username/custom_model_name

# Load weights without 8-bit or 4-bit quantization.
load_in_4bit: false
load_in_8bit: false
strict: false

# torch.compile is disabled for this run.
torch_compile: false

# vLLM generation server. The trainer connects to it through
# trl.vllm_server_host / trl.vllm_server_port, so keep the port in sync.
vllm:
  host: 0.0.0.0
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85
  dtype: auto
  # max_model_len:  # set the vLLM context length here if it is known beforehand

rl: grpo
trl:
  # loss_type: dr_grpo  # or: dapo
  # Generate completions on the external vLLM server from the top-level
  # `vllm` block; the port here must match vllm.port.
  use_vllm: true
  vllm_server_host: localhost
  vllm_server_port: 8000
  vllm_server_timeout: 300
  beta: 0.001
  max_completion_length: 512
  # Dotted module.function paths; the gsm8k_grpo module must be importable
  # (on PYTHONPATH) when training starts.
  reward_funcs:
    - gsm8k_grpo.correctness_reward_func
    - gsm8k_grpo.int_reward_func
    - gsm8k_grpo.strict_format_reward_func
    - gsm8k_grpo.soft_format_reward_func
    - gsm8k_grpo.xmlcount_reward_func
  # NOTE(review): with a vLLM server, engine memory/context are presumably
  # governed by the top-level `vllm` block, so these two keys may be ignored.
  # If they are honored, a 256-token model length cannot fit
  # max_completion_length (512) — confirm.
  vllm_gpu_memory_utilization: 0.9
  vllm_max_model_len: 256
  num_generations: 2

chat_template: qwen3

# The transform below lives in the gsm8k_grpo module, which must be importable:
#   export PYTHONPATH=/home/axolotl/axolotl-cookbook/grpo/:$PYTHONPATH
datasets:
  - path: skrishna/gsm8k_only_answer
    type: gsm8k_grpo.axo_gsm8k_transform
skip_prepare_dataset: true
dataset_prepared_path: null
# No validation split is carved out of the training data.
val_set_size: 0.0
output_dir: ./outputs/out

# Data loading: 2 worker processes, prefetch factor 32, pinned host memory.
dataloader_num_workers: 2
dataloader_prefetch_factor: 32
dataloader_pin_memory: true

# Garbage-collection interval, in training steps.
gc_steps: 1

# Sequence handling: no sample packing (train or eval), no padding to a fixed length.
sequence_len: 256
pad_to_sequence_len: false
sample_packing: false
eval_sample_packing: false

# Weights & Biases settings; all left unset (null) here.
wandb_project: null
wandb_entity: null
wandb_name: null

# Batching: micro_batch_size should match trl.num_generations / num_gpus.
micro_batch_size: 2
gradient_accumulation_steps: 8
num_epochs: 1

# Optimizer: fused AdamW with a constant learning rate after warmup.
optimizer: adamw_torch_fused
learning_rate: 1.0e-6
lr_scheduler: constant_with_warmup
weight_decay: 0.1
max_grad_norm: 1.0

# Numeric precision.
bf16: auto
tf32: false

# Memory and attention: non-reentrant activation checkpointing plus flash attention.
flash_attention: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

# Warmup, logging, checkpoint and evaluation cadence.
warmup_steps: 100
logging_steps: 1
saves_per_epoch: 4
# NOTE(review): val_set_size is 0.0, so there may be no eval split for this
# setting to act on — confirm whether evaluation is intended.
evals_per_epoch: 1