"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "cfc99adf0f2e45afbddc117671e4faa59ca83ae2"
Unverified Commit 23c566a6 authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Adding `--num_workers` input parameter to the EEG_GCNN example. (#6467)

parent 760426e4
...@@ -37,6 +37,12 @@ if __name__ == "__main__": ...@@ -37,6 +37,12 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--num_nodes", type=int, default=8, help="Number of nodes in the graph" "--num_nodes", type=int, default=8, help="Number of nodes in the graph"
) )
parser.add_argument(
"--num_workers",
type=int,
default=4,
help="Number of epochs used to train",
)
parser.add_argument( parser.add_argument(
"--gpu_idx", "--gpu_idx",
type=int, type=int,
...@@ -97,6 +103,7 @@ if __name__ == "__main__": ...@@ -97,6 +103,7 @@ if __name__ == "__main__":
_EXPERIMENT_NAME = args.exp_name _EXPERIMENT_NAME = args.exp_name
_BATCH_SIZE = args.batch_size _BATCH_SIZE = args.batch_size
num_feats = args.num_feats num_feats = args.num_feats
num_workers = args.num_workers
# set up input and targets from files # set up input and targets from files
x = _load_memory_mapped_array(f"psd_features_data_X") x = _load_memory_mapped_array(f"psd_features_data_X")
...@@ -149,7 +156,6 @@ if __name__ == "__main__": ...@@ -149,7 +156,6 @@ if __name__ == "__main__":
# Dataloader======================================================================================================== # Dataloader========================================================================================================
# use WeightedRandomSampler to balance the training dataset # use WeightedRandomSampler to balance the training dataset
NUM_WORKERS = 4
labels_unique, counts = np.unique(y, return_counts=True) labels_unique, counts = np.unique(y, return_counts=True)
...@@ -172,7 +178,7 @@ if __name__ == "__main__": ...@@ -172,7 +178,7 @@ if __name__ == "__main__":
dataset=train_dataset, dataset=train_dataset,
batch_size=_BATCH_SIZE, batch_size=_BATCH_SIZE,
sampler=weighted_sampler, sampler=weighted_sampler,
num_workers=NUM_WORKERS, num_workers=num_workers,
pin_memory=True, pin_memory=True,
) )
...@@ -181,7 +187,7 @@ if __name__ == "__main__": ...@@ -181,7 +187,7 @@ if __name__ == "__main__":
dataset=train_dataset, dataset=train_dataset,
batch_size=_BATCH_SIZE, batch_size=_BATCH_SIZE,
shuffle=False, shuffle=False,
num_workers=NUM_WORKERS, num_workers=num_workers,
pin_memory=True, pin_memory=True,
) )
...@@ -194,7 +200,7 @@ if __name__ == "__main__": ...@@ -194,7 +200,7 @@ if __name__ == "__main__":
dataset=test_dataset, dataset=test_dataset,
batch_size=_BATCH_SIZE, batch_size=_BATCH_SIZE,
shuffle=False, shuffle=False,
num_workers=NUM_WORKERS, num_workers=num_workers,
pin_memory=True, pin_memory=True,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment