Unverified Commit de6bf28e authored by Ben Graham's avatar Ben Graham Committed by GitHub
Browse files

random seed init

parent 2082f213
......@@ -10,7 +10,7 @@ val_reps=1 # Number of test views, 1 or more
batch_size=32
elastic_deformation=False
import torch, numpy as np, glob, math, torch.utils.data, scipy.ndimage, multiprocessing as mp
import torch, numpy as np, glob, math, torch.utils.data, scipy.ndimage, multiprocessing as mp, time
dimension=3
full_scale=4096 #Input field size
......@@ -82,8 +82,14 @@ def trainMerge(tbl):
labels=torch.cat(labels,0)
return {'x': [locs,feats], 'y': labels.long(), 'id': tbl}
train_data_loader = torch.utils.data.DataLoader(
list(range(len(train))),batch_size=batch_size, collate_fn=trainMerge, num_workers=20, shuffle=True)
list(range(len(train))),
batch_size=batch_size,
collate_fn=trainMerge,
num_workers=20,
shuffle=True,
drop_last=True,
worker_init_fn=lambda x: np.random.seed(x+int(time.time()))
)
valOffsets=[0]
valLabels=[]
......@@ -125,4 +131,10 @@ def valMerge(tbl):
point_ids=torch.cat(point_ids,0)
return {'x': [locs,feats], 'y': labels.long(), 'id': tbl, 'point_ids': point_ids}
val_data_loader = torch.utils.data.DataLoader(
list(range(len(val))),batch_size=batch_size, collate_fn=valMerge, num_workers=20,shuffle=True)
list(range(len(val))),
batch_size=batch_size,
collate_fn=valMerge,
num_workers=20,
shuffle=True,
worker_init_fn=lambda x: np.random.seed(x+int(time.time()))
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment