# gpt3-2.7B-flash-8k.yaml
# @package _global_

# Experiment config for a GPT-3 2.7B-scale model with an 8k context length.
# Inherits everything from the GPT2-XL flash 8k Pile experiment and overrides
# only the model size, batch sizing, and learning rate below.
defaults:
  - /experiment/pile/gpt2xl-flash-8k.yaml

model:
  config:
    # Scale up to GPT-3 "2.7B" dimensions (d_model=2560, 32 heads, 32 layers;
    # head_dim = 2560 / 32 = 80).
    n_embd: 2560
    n_head: 32
    n_layer: 32
    # Init std = sqrt(2 / (5 * n_embd)); `${.n_embd}` is a relative
    # interpolation to the sibling key above. NOTE(review): presumably the
    # small-init scheme used for large-model stability — confirm against the
    # base config / training recipe.
    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
    # NOTE(review): looks like an activation-checkpointing level for the MLP
    # blocks (0 = least/none) — confirm semantics in the model code.
    mlp_checkpoint_lvl: 0

datamodule:
  # Per-GPU micro-batch size picked from available GPU memory
  # (train.gpu_mem, presumably in GB): <40 -> 1, <80 -> 2, else 4.
  batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}

train:
  optimizer:
    # Peak LR override. NOTE(review): 1.6e-4 matches the GPT-3 paper's rate
    # for the 2.7B model — verify against the inherited schedule.
    lr: 1.6e-4