# @package _global_
defaults:
  - override /trainer: default # choose trainer from 'configs/trainer/'
  - override /model: null
  - override /datamodule: thepile
  - override /optimizer: adamw-apex  # slight speedup (1-2%) over PyTorch AdamW
  - override /scheduler: cosine-warmup-timm
  - override /callbacks: [default, norm-monitor]
  - override /metrics: [perplexity, num-tokens]
  - override /logger: wandb

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

task:
  _target_: src.tasks.seq.SequenceLMModel

seed: 1111

trainer:
  accelerator: gpu
  devices: 8
  num_nodes: 1
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}
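  # Worked example with the defaults in this file: global_batch_size=256,
  # devices=8, batch_size=16, num_nodes=1, so
  #   accumulate_grad_batches = div_up(256, 8 * 16 * 1) = 2,
  # i.e. each GPU accumulates two micro-batches per optimizer step.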
  max_steps: 800000
  val_check_interval: ${eval:2000 * ${.accumulate_grad_batches}}
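  # With accumulate_grad_batches = 2 this resolves to 4000; Lightning counts
  # micro-batches here, so validation runs every 2000 optimizer steps.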
  check_val_every_n_epoch: null  # We don't care about epoch boundary
  precision: bf16
  gradient_clip_val: 1.0
  strategy: null

datamodule:
  batch_size: 16  # Per GPU
  batch_size_eval: ${.batch_size}  # Fused dense only supports batch sizes up to 64k
  max_length: 2048
  fault_tolerant: True
  ddp: ${eval:"${trainer.devices} > 1"}

train:
  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
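  # The resolver above shells out to nvidia-smi, reads the total memory of
  # GPU 0 in MiB, and rounds it to whole GB; e.g. a 40 GB card reporting
  # ~40960 MiB gives gpu_mem = 41. Downstream configs can use this value to
  # scale settings to the hardware.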
  global_batch_size: 256
  optimizer:
    lr: 6e-4
    weight_decay: 0.1
  optimizer_param_grouping:
    bias_weight_decay: False
    normalization_weight_decay: False
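    # i.e. biases and normalization-layer weights go into a parameter group
    # with weight decay disabled, the usual practice for transformer LMs.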
  scheduler:
    t_in_epochs: False
    t_initial: 600000
    warmup_lr_init: 1e-6
    warmup_t: ${eval:0.01 * ${trainer.max_steps}}
    lr_min: ${eval:0.1 * ${train.optimizer.lr}}
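    # With the values above: warmup_t = 0.01 * 800000 = 8000 optimizer steps
    # of warmup from warmup_lr_init=1e-6 up to lr=6e-4, then cosine decay
    # toward lr_min = 0.1 * 6e-4 = 6e-5 over t_initial = 600000 steps.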
  loss_fn:
    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
    # It's also more numerically stable when training in 16-bit precision with DeepSpeed.
    _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss
    inplace_backward: True  # to save memory
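    # inplace_backward writes the gradient over the logits buffer during the
    # backward pass (presumably safe here since the logits are not reused),
    # trading an extra allocation for memory savings at large vocab sizes.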

eval:
  log_on_step: True  # One training epoch takes too long; we want metrics per training step

callbacks:
  model_checkpoint:
    monitor: val/loss
    mode: min
    save_top_k: 3
    save_last: True
    every_n_train_steps: 1000
    dirpath: ${work_dir}/checkpoints/${oc.select:name,''}
    filename: step_{step}
    auto_insert_metric_name: False
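    # With auto_insert_metric_name disabled, checkpoints are written as e.g.
    # step_1000.ckpt under ${work_dir}/checkpoints/<name>, keeping the 3 best
    # by val/loss plus the last one.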
  model_checkpoint_progress:
    _target_: src.callbacks.model_checkpoint.ModelCheckpointMine
    # fault_tolerant: True  # The .pl_auto_save.ckpt doesn't get saved by all workers
    every_n_train_steps: 50000
    save_last: False
    save_top_k: -1  # Save all the checkpoints
    dirpath: ${..model_checkpoint.dirpath}
    filename: progress_step_{step}
    auto_insert_metric_name: False
  early_stopping: null