Commit e6b5a952 authored by Andrea Paris

first patch

parent 5da00de0
@@ -138,7 +138,7 @@ def validate_model(
         for metric in metrics_fns:
             metric_buff = metrics[metric]
             metric_fn = metrics_fns[metric]
-            metric_buff[idx] = metric_fn(prd, tar.unsqueeze(-3), mask)
+            metric_buff[idx] = metric_fn(prd, tar, mask)
         tar = (tar * mask).squeeze()
         prd = (prd * mask).squeeze()
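
Aside on the hunk above: the patched call passes tar to the metric functions unchanged instead of adding a dimension with unsqueeze(-3). A minimal sketch of a masked metric that assumes prd, tar, and mask already share a broadcast-compatible layout; the function name and the tensor layout below are assumptions for illustration, not code from this repository:

# Hypothetical sketch only; masked_rmse and the [batch, channel, lat, lon]
# layout are assumptions, not the repository's actual metric implementation.
import torch

def masked_rmse(prd: torch.Tensor, tar: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Weight squared errors by the mask so masked-out points do not contribute,
    # then normalize by the number of unmasked points.
    err = (prd - tar) ** 2 * mask
    denom = mask.to(err.dtype).sum().clamp_min(1.0)
    return torch.sqrt(err.sum() / denom)

# With the patched call, tar is used as-is; if prd and tar already share a
# layout such as [batch, channel, lat, lon], an extra tar.unsqueeze(-3) would
# only introduce an unwanted broadcast dimension.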
@@ -257,7 +257,7 @@ def train_model(
     # prepare metrics buffer for accumulation of validation metrics
     valid_metrics = {}
     for metric in metrics_fns:
-        valid_metrics[metric] = torch.zeros(1, dtype=torch.float32, device=device)
+        valid_metrics[metric] = torch.zeros(2, dtype=torch.float32, device=device)
     model.eval()
@@ -287,6 +287,7 @@ def train_model(
                 metric_buff = valid_metrics[metric]
                 metric_fn = metrics_fns[metric]
                 metric_buff[0] += metric_fn(prd, tar, mask) * inp.size(0)
+                metric_buff[1] += inp.size(0)
         if dist.is_initialized():
             dist.all_reduce(valid_loss)
@@ -294,8 +295,9 @@ def train_model(
                 dist.all_reduce(valid_metrics[metric])
         valid_loss = (valid_loss[0] / valid_loss[1]).item()
         for metric in valid_metrics:
-            valid_metrics[metric] = (valid_metrics[metric][0] / valid_loss[1]).item()
+            valid_metrics[metric] = (valid_metrics[metric][0] / valid_metrics[metric][1]).item()
         if scheduler is not None:
             scheduler.step(valid_loss)
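
The three train_model hunks above switch each validation-metric buffer from a single accumulator to a pair of (batch-size-weighted sum, sample count), and fix the final division, which previously used valid_loss[1] instead of the metric's own count. A minimal sketch of that accumulation pattern, with helper names assumed for illustration rather than taken from the repository:

# Sketch of the sum/count accumulation pattern; init_metric_buffers,
# accumulate, and finalize are assumed names, not functions from this repo.
import torch
import torch.distributed as dist

def init_metric_buffers(metric_names, device):
    # One [weighted_sum, sample_count] buffer per metric.
    return {m: torch.zeros(2, dtype=torch.float32, device=device) for m in metric_names}

def accumulate(buff: torch.Tensor, batch_metric: torch.Tensor, batch_size: int):
    buff[0] += batch_metric * batch_size   # metric weighted by batch size
    buff[1] += batch_size                  # number of samples seen

def finalize(buffers):
    # all_reduce sums both the weighted metric and the count over all ranks,
    # so dividing afterwards yields the global per-sample average.
    if dist.is_initialized():
        for buff in buffers.values():
            dist.all_reduce(buff)
    return {m: (buff[0] / buff[1]).item() for m, buff in buffers.items()}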
@@ -435,16 +437,16 @@ def main(
     # specify which models to train here
     models = [
-        "transformer_sc2_layers4_e128",
-        "s2transformer_sc2_layers4_e128",
-        "ntransformer_sc2_layers4_e128",
+        #"transformer_sc2_layers4_e128",
+        #"s2transformer_sc2_layers4_e128",
+        #"ntransformer_sc2_layers4_e128",
         "s2ntransformer_sc2_layers4_e128",
-        "segformer_sc2_layers4_e128",
-        "s2segformer_sc2_layers4_e128",
-        "nsegformer_sc2_layers4_e128",
-        "s2nsegformer_sc2_layers4_e128",
-        "sfno_sc2_layers4_e32",
-        "lsno_sc2_layers4_e32",
+        #"segformer_sc2_layers4_e128",
+        #"s2segformer_sc2_layers4_e128",
+        #"nsegformer_sc2_layers4_e128",
+        #"s2nsegformer_sc2_layers4_e128",
+        #"sfno_sc2_layers4_e32",
+        #"lsno_sc2_layers4_e32",
     ]
     models = {k: baseline_models[k] for k in models}
......