Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
23c566a6
Unverified
Commit
23c566a6
authored
Oct 25, 2023
by
Andrei Ivanov
Committed by
GitHub
Oct 26, 2023
Browse files
Adding `--num_workers` input parameter to the EEG_GCNN example. (#6467)
parent
760426e4
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
4 deletions
+10
-4
examples/pytorch/eeg-gcnn/main.py
examples/pytorch/eeg-gcnn/main.py
+10
-4
No files found.
examples/pytorch/eeg-gcnn/main.py
View file @
23c566a6
...
...
@@ -37,6 +37,12 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--num_nodes"
,
type
=
int
,
default
=
8
,
help
=
"Number of nodes in the graph"
)
parser
.
add_argument
(
"--num_workers"
,
type
=
int
,
default
=
4
,
help
=
"Number of epochs used to train"
,
)
parser
.
add_argument
(
"--gpu_idx"
,
type
=
int
,
...
...
@@ -97,6 +103,7 @@ if __name__ == "__main__":
_EXPERIMENT_NAME
=
args
.
exp_name
_BATCH_SIZE
=
args
.
batch_size
num_feats
=
args
.
num_feats
num_workers
=
args
.
num_workers
# set up input and targets from files
x
=
_load_memory_mapped_array
(
f
"psd_features_data_X"
)
...
...
@@ -149,7 +156,6 @@ if __name__ == "__main__":
# Dataloader========================================================================================================
# use WeightedRandomSampler to balance the training dataset
NUM_WORKERS
=
4
labels_unique
,
counts
=
np
.
unique
(
y
,
return_counts
=
True
)
...
...
@@ -172,7 +178,7 @@ if __name__ == "__main__":
dataset
=
train_dataset
,
batch_size
=
_BATCH_SIZE
,
sampler
=
weighted_sampler
,
num_workers
=
NUM_WORKERS
,
num_workers
=
num_workers
,
pin_memory
=
True
,
)
...
...
@@ -181,7 +187,7 @@ if __name__ == "__main__":
dataset
=
train_dataset
,
batch_size
=
_BATCH_SIZE
,
shuffle
=
False
,
num_workers
=
NUM_WORKERS
,
num_workers
=
num_workers
,
pin_memory
=
True
,
)
...
...
@@ -194,7 +200,7 @@ if __name__ == "__main__":
dataset
=
test_dataset
,
batch_size
=
_BATCH_SIZE
,
shuffle
=
False
,
num_workers
=
NUM_WORKERS
,
num_workers
=
num_workers
,
pin_memory
=
True
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment