Commit e6b5a952 authored by Andrea Paris
Browse files

first patch

parent 5da00de0
......@@ -138,7 +138,7 @@ def validate_model(
for metric in metrics_fns:
metric_buff = metrics[metric]
metric_fn = metrics_fns[metric]
metric_buff[idx] = metric_fn(prd, tar.unsqueeze(-3), mask)
metric_buff[idx] = metric_fn(prd, tar, mask)
tar = (tar * mask).squeeze()
prd = (prd * mask).squeeze()
......@@ -257,7 +257,7 @@ def train_model(
# prepare metrics buffer for accumulation of validation metrics
valid_metrics = {}
for metric in metrics_fns:
valid_metrics[metric] = torch.zeros(1, dtype=torch.float32, device=device)
valid_metrics[metric] = torch.zeros(2, dtype=torch.float32, device=device)
model.eval()
......@@ -287,6 +287,7 @@ def train_model(
metric_buff = valid_metrics[metric]
metric_fn = metrics_fns[metric]
metric_buff[0] += metric_fn(prd, tar, mask) * inp.size(0)
metric_buff[1] += inp.size(0)
if dist.is_initialized():
dist.all_reduce(valid_loss)
......@@ -294,8 +295,9 @@ def train_model(
dist.all_reduce(valid_metrics[metric])
valid_loss = (valid_loss[0] / valid_loss[1]).item()
for metric in valid_metrics:
valid_metrics[metric] = (valid_metrics[metric][0] / valid_loss[1]).item()
valid_metrics[metric] = (valid_metrics[metric][0] / valid_metrics[metric][1]).item()
if scheduler is not None:
scheduler.step(valid_loss)
......@@ -435,16 +437,16 @@ def main(
# specify which models to train here
models = [
"transformer_sc2_layers4_e128",
"s2transformer_sc2_layers4_e128",
"ntransformer_sc2_layers4_e128",
#"transformer_sc2_layers4_e128",
#"s2transformer_sc2_layers4_e128",
#"ntransformer_sc2_layers4_e128",
"s2ntransformer_sc2_layers4_e128",
"segformer_sc2_layers4_e128",
"s2segformer_sc2_layers4_e128",
"nsegformer_sc2_layers4_e128",
"s2nsegformer_sc2_layers4_e128",
"sfno_sc2_layers4_e32",
"lsno_sc2_layers4_e32",
#"segformer_sc2_layers4_e128",
#"s2segformer_sc2_layers4_e128",
#"nsegformer_sc2_layers4_e128",
#"s2nsegformer_sc2_layers4_e128",
#"sfno_sc2_layers4_e32",
#"lsno_sc2_layers4_e32",
]
models = {k: baseline_models[k] for k in models}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment