# @package _global_
defaults:
  - override /trainer: default # choose trainer from 'configs/trainer/'
  - override /model: null
  - override /datamodule: openwebtext
  # FusedAdam from apex speeds up the optimizer step a bit: for GPT2-small, time
  # per global step (i.e. batch size 512) on 8 A100s goes from 376ms to 368ms.
  # For GPT2-medium, time per global step goes from 997ms to 972ms.
  - override /optimizer: adamw-apex
  - override /scheduler: linear-warmup
  - override /callbacks: [default, norm-monitor]
  - override /metrics: [perplexity, num-tokens]
  - override /logger: wandb

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

task:
  _target_: src.tasks.seq.SequenceLMModel

seed: 1111

trainer:
  accelerator: gpu
  devices: 8
  num_nodes: 1
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}
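  # With the values below (global_batch_size=512, 8 devices x per-GPU batch_size 16 x 1 node),
  # this resolves to div_up(512, 128) = 4 gradient-accumulation steps.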
  max_steps: 400000
  val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}
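  # val_check_interval is counted in micro-batches, so multiplying by
  # accumulate_grad_batches runs validation every 1000 global (optimizer) steps.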
  check_val_every_n_epoch: null  # We don't care about epoch boundaries
  precision: 16
  gradient_clip_val: 1.0
  strategy: null

datamodule:
  batch_size: 16  # Per GPU
  batch_size_eval: ${.batch_size}  # Fused dense only supports batch sizes up to 64k
  max_length: 1024
  fault_tolerant: True
  ddp: ${eval:"${trainer.devices} > 1"}
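  # Assumption: fault_tolerant enables a resumable dataloader, and ddp (True here,
  # since trainer.devices > 1) switches the datamodule to distributed-aware sampling.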

train:
  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
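  # Queries nvidia-smi for the total memory of GPU 0 and rounds it to (roughly) GB;
  # presumably used by derived configs to scale the per-GPU batch size.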
  global_batch_size: 512
  optimizer:
    lr: 6e-4
    weight_decay: 0.1
  optimizer_param_grouping:
    bias_weight_decay: False
    normalization_weight_decay: False
  scheduler:
    num_warmup_steps: ${eval:0.01 * ${trainer.max_steps}}
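    # 0.01 * 400000 max_steps = 4000 warmup steps.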
    num_training_steps: ${trainer.max_steps}
  loss_fn:
    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
    # It's also more numerically stable when training in 16-bit precision with DeepSpeed.
    _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss
    inplace_backward: True  # to save memory
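    # Assumption: inplace_backward=True lets the backward pass overwrite the logits
    # tensor with its gradient instead of allocating a new one.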

eval:
  log_on_step: True  # One training epoch takes too long; we want to see metrics per training step

callbacks:
  model_checkpoint:
    monitor: val/loss
    mode: min
    save_top_k: 3
    save_last: True
    every_n_train_steps: 1000
    dirpath: ${work_dir}/checkpoints/${oc.select:name,''}
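    # oc.select falls back to '' when `name` is not set, so checkpoints then land
    # directly under ${work_dir}/checkpoints.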
    filename: step_{step}
    auto_insert_metric_name: False
  model_checkpoint_progress:
    _target_: src.callbacks.model_checkpoint.ModelCheckpointMine
    fault_tolerant: True
    every_n_train_steps: 50000
    save_last: False
    save_top_k: -1  # Save all the checkpoints
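    # These periodic full snapshots complement the top-3 val/loss checkpoints
    # kept by model_checkpoint above, e.g. for resuming long runs.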
    dirpath: ${..model_checkpoint.dirpath}
    filename: progress_step_{step}
    auto_insert_metric_name: False
  early_stopping: null