# rl_qwen3-4b.yaml - v1.0 (committed by chenzk)
# http://docs.axolotl.ai/docs/rlhf.html#grpo
# https://github.com/axolotl-ai-cloud/axolotl-cookbook/blob/main/grpo/gsm8k.yaml

# Base checkpoint to train from.
base_model: Qwen/Qwen3-4B
# To automatically upload checkpoints and the final model to the HF Hub,
# uncomment the next line and set your own repository id:
# hub_model_id: username/custom_model_name

# Both quantized-loading flags are off.
load_in_4bit: false
load_in_8bit: false
strict: false

torch_compile: false

# vLLM settings. The port set here (8000) is the same value that
# `trl.vllm_server_port` uses further down in this file.
vllm:
  # 0.0.0.0 is the IPv4 wildcard address; quoted so it is always read as a
  # string.
  host: "0.0.0.0"
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85
  dtype: auto
  # You may find it useful to set the vLLM model context length if you know
  # it beforehand:
  # max_model_len:

# RL method. `dr_grpo` is noted as the alternative value.
rl: grpo  # dr_grpo
trl:
  # Generation goes through vLLM. This key must appear only once in the
  # mapping: duplicate keys are invalid YAML and parsers silently keep the
  # last occurrence.
  use_vllm: true
  vllm_server_host: localhost
  vllm_server_port: 8000  # same port as `vllm.port`
  vllm_server_timeout: 300
  beta: 0.001
  max_completion_length: 512
  # Reward functions, given as `<module>.<function>` in the gsm8k_grpo module.
  reward_funcs:
    - gsm8k_grpo.correctness_reward_func
    - gsm8k_grpo.int_reward_func
    - gsm8k_grpo.strict_format_reward_func
    - gsm8k_grpo.soft_format_reward_func
    - gsm8k_grpo.xmlcount_reward_func
  # NOTE(review): this differs from `vllm.gpu_memory_utilization` (0.85);
  # confirm which of the two is actually honored in this setup.
  vllm_gpu_memory_utilization: 0.9
  # NOTE(review): 256 is smaller than max_completion_length (512) above;
  # confirm this limit is intended.
  vllm_max_model_len: 256
  num_generations: 2

chat_template: qwen3
# The dataset `type` refers to the gsm8k_grpo module; presumably it has to be
# importable, which is what the following shell command is for:
#   export PYTHONPATH=/home/axolotl/axolotl-cookbook/grpo/:$PYTHONPATH
datasets:
  - type: gsm8k_grpo.axo_gsm8k_transform
    path: skrishna/gsm8k_only_answer
# Left unset (null).
dataset_prepared_path: null
skip_prepare_dataset: true
# No validation split is held out.
val_set_size: 0.0
output_dir: ./outputs/out

# Data loader settings.
dataloader_num_workers: 2
dataloader_prefetch_factor: 32
dataloader_pin_memory: true

# NOTE(review): presumably the garbage-collection interval in steps; verify
# against the Axolotl config reference.
gc_steps: 1

# Sequence length setting; same value (256) as `trl.vllm_max_model_len`.
sequence_len: 256
# Packing and padding-to-length are all disabled.
pad_to_sequence_len: false
sample_packing: false
eval_sample_packing: false

# Weights & Biases fields are left unset (null); fill them in to name the
# project, entity and run.
wandb_project: null
wandb_entity: null
wandb_name: null

# Batch settings.
# micro_batch_size should match num_generations / num_gpus
# (`trl.num_generations` is 2 in this file).
micro_batch_size: 2
gradient_accumulation_steps: 8
num_epochs: 1

# Optimizer and learning-rate schedule.
optimizer: adamw_torch_fused
learning_rate: 1.0e-6
lr_scheduler: constant_with_warmup
weight_decay: 0.1
max_grad_norm: 1.0

# Numeric precision: bf16 is set to `auto`, tf32 is disabled.
bf16: auto
tf32: false

# Gradient checkpointing is enabled with use_reentrant set to false;
# flash attention is enabled.
flash_attention: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

# Warmup length, used together with `lr_scheduler: constant_with_warmup`.
warmup_steps: 100
# Logging interval in steps.
logging_steps: 1
# Evaluation and checkpoint frequency per epoch.
# NOTE(review): `val_set_size` is 0.0 in this file, so confirm whether
# evals_per_epoch has any effect.
evals_per_epoch: 1
saves_per_epoch: 4