# AFNO.yaml - FourCastNet AFNO model configurations
### base config ###
# Base configuration. `afno_backbone` merges it via `<<: *FULL_FIELD`, and the
# configs below build on `afno_backbone`; any key they set overrides the value here.
# NOTE(review): values written as `None` are plain YAML strings, not null;
# presumably the config loader maps the string 'None' to a null value - confirm.
full_field: &FULL_FIELD
  loss: 'l2'
  # Explicit !!float: YAML 1.1 loaders (e.g. PyYAML) only resolve floats that
  # contain a '.', so a bare 1E-3 would be read as the string '1E-3'.
  lr: !!float 1E-3
  scheduler: 'ReduceLROnPlateau'
  num_data_workers: 4
  dt: 1 # how many timesteps ahead the model will predict
  n_history: 0 # how many previous timesteps to consider
  prediction_type: 'iterative'
  prediction_length: 41 # applicable only if prediction_type == 'iterative'
  n_initial_conditions: 5 # applicable only if prediction_type == 'iterative'
  ics_type: "default"
  save_raw_forecasts: !!bool True
  save_channel: !!bool False
  masked_acc: !!bool False
  maskpath: None
  perturb: !!bool False
  add_grid: !!bool False
  N_grid_channels: 0
  gridtype: 'sinusoidal' # options: 'sinusoidal' or 'linear'
  roll: !!bool False
  max_epochs: 50
  batch_size: 64

  # AFNO hyperparameters
  num_blocks: 8
  nettype: 'afno'
  patch_size: 8
  width: 56
  modes: 32
  target: 'default' # options: default, residual
  in_channels: [0, 1]
  out_channels: [0, 1] # must be the same as in_channels if prediction_type == 'iterative'
  normalization: 'zscore' # options: zscore (minmax not supported)
  train_data_path: '/pscratch/sd/j/jpathak/wind/train'
  valid_data_path: '/pscratch/sd/j/jpathak/wind/test'
  inf_data_path: '/pscratch/sd/j/jpathak/wind/out_of_sample' # test set path for inference
  exp_dir: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind'
  time_means_path:   '/pscratch/sd/j/jpathak/wind/time_means.npy'
  global_means_path: '/pscratch/sd/j/jpathak/wind/global_means.npy'
  global_stds_path:  '/pscratch/sd/j/jpathak/wind/global_stds.npy'

  orography: !!bool False
  orography_path: None

  log_to_screen: !!bool True
  log_to_wandb: !!bool True
  save_checkpoint: !!bool True

  enable_nhwc: !!bool False
  optimizer_type: 'FusedAdam'
  crop_size_x: None
  crop_size_y: None

  two_step_training: !!bool False
  plot_animations: !!bool False

  add_noise: !!bool False
  noise_std: 0

# Main 20-channel AFNO backbone config, built on the base `full_field` config.
# The other configs in this file merge it via `<<: *backbone`.
afno_backbone: &backbone
  <<: *FULL_FIELD
  log_to_wandb: !!bool True
  # Explicit !!float: YAML 1.1 loaders (e.g. PyYAML) read a bare 5E-4 (no '.') as a string.
  lr: !!float 5E-4
  batch_size: 64
  max_epochs: 150
  scheduler: 'CosineAnnealingLR'
  in_channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
  out_channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
  orography: !!bool False
  orography_path: None
  # NOTE(review): exp_dir still points at the original /pscratch cluster path while the
  # data and statistics paths below use /workspace - confirm it is writable here.
  exp_dir: '/pscratch/sd/s/shas1693/results/era5_wind'

  # Original cluster paths, kept for reference:
  # train_data_path: '/pscratch/sd/s/shas1693/data/era5/train'
  # valid_data_path: '/pscratch/sd/s/shas1693/data/era5/test'
  # inf_data_path:   '/pscratch/sd/s/shas1693/data/era5/out_of_sample'
  # time_means_path:   '/pscratch/sd/s/shas1693/data/era5/time_means.npy'
  # global_means_path: '/pscratch/sd/s/shas1693/data/era5/global_means.npy'
  # global_stds_path:  '/pscratch/sd/s/shas1693/data/era5/global_stds.npy'

  # ==== Data paths (important) ====
  train_data_path: '/workspace/FourCastNet/data/train'
  # Validation and inference reuse the training directory for now. Runs work this way,
  # but validation/inference results are then not on held-out data.
  valid_data_path: '/workspace/FourCastNet/data/train'
  inf_data_path:   '/workspace/FourCastNet/data/train'

  # ==== Statistics (if you don't have these files yet, handle it this way for now) ====
  time_means_path:   '/workspace/FourCastNet/data/time_means.npy'
  global_means_path: '/workspace/FourCastNet/data/global_means.npy'
  global_stds_path:  '/workspace/FourCastNet/data/global_stds.npy'


# Same as afno_backbone, but with `orography` switched on and a path to the orography file.
# NOTE(review): orography_path is still on the /pscratch cluster even though
# afno_backbone's data paths point at /workspace - confirm it exists in this environment.
afno_backbone_orography: &backbone_orography
  <<: *backbone
  orography: !!bool True
  orography_path: "/pscratch/sd/s/shas1693/data/era5/static/orography.h5"

# Fine-tuning config: starts from the afno_backbone checkpoint given in
# pretrained_ckpt_path, with a lower learning rate than afno_backbone and
# two-step training enabled.
afno_backbone_finetune:
  <<: *backbone
  # Explicit !!float: YAML 1.1 loaders (e.g. PyYAML) read a bare 1E-4 (no '.') as a string.
  lr: !!float 1E-4
  batch_size: 64
  log_to_wandb: !!bool True
  max_epochs: 50
  pretrained: !!bool True
  two_step_training: !!bool True
  pretrained_ckpt_path: '/pscratch/sd/s/shas1693/results/era5_wind/afno_backbone/0/training_checkpoints/best_ckpt.tar'

# Perturbation config built on afno_backbone: pretrained weights from
# pretrained_ckpt_path, datetime initial conditions taken from date_strings,
# and `perturb` enabled with n_perturbations / n_level.
perturbations:
  <<: *backbone
  # Explicit !!float: YAML 1.1 loaders (e.g. PyYAML) read a bare 1E-4 (no '.') as a string.
  lr: !!float 1E-4
  batch_size: 64
  max_epochs: 50
  pretrained: !!bool True
  two_step_training: !!bool True
  pretrained_ckpt_path: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind/afno_20ch_bs_64_lr5em4_blk_8_patch_8_cosine_sched/1/training_checkpoints/best_ckpt.tar'
  prediction_length: 24
  ics_type: "datetime"
  n_perturbations: 100
  # `!!bool` (two bangs) is the standard YAML boolean tag; a single `!bool` is an
  # undefined local tag that loaders do not construct as a boolean.
  save_channel: !!bool True
  save_idx: 4
  save_raw_forecasts: !!bool False
  date_strings: ["2018-01-01 00:00:00"]
  # NOTE(review): the next two values look like redacted placeholders (a lone space,
  # and a directory with a trailing space) - set real values before running.
  inference_file_tag: " "
  valid_data_path: "/pscratch/sd/j/jpathak/ "
  perturb: !!bool True
  n_level: 0.3

### PRECIP ###
# Precipitation config built on afno_backbone: all 20 channels (0-19) in,
# a single output channel (index 0).
precip: &precip
  <<: *backbone
  # --- model ---
  in_channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
  out_channels: [0]
  nettype: "afno"
  nettype_wind: "afno"
  # --- training / logging ---
  log_to_wandb: !!bool True
  lr: 2.5E-4
  batch_size: 64
  max_epochs: 25
  # --- precipitation data and the trained wind-model checkpoint ---
  precip: "/pscratch/sd/p/pharring/ERA5/precip/total_precipitation"
  time_means_path_tp: "/pscratch/sd/p/pharring/ERA5/precip/total_precipitation/time_means.npy"
  model_wind_path: "/pscratch/sd/s/shas1693/results/era5_wind/afno_backbone_finetune/0/training_checkpoints/best_ckpt.tar"
  precip_eps: !!float 1e-5