"model/models/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "0e886595bf3d4ee33737f4b30154210b0df2d2df"
Commit 628441b7 authored by ThangVu's avatar ThangVu
Browse files

caffe2 preprocess in group norm unit test

parent 8500c14e
...@@ -197,8 +197,7 @@ def main_worker(gpu, ngpus_per_node, args): ...@@ -197,8 +197,7 @@ def main_worker(gpu, ngpus_per_node, args):
traindir = os.path.join(args.data, 'train') traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val') valdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]) std=[1/255, 1/255, 1/255])
train_dataset = datasets.ImageFolder( train_dataset = datasets.ImageFolder(
traindir, traindir,
transforms.Compose([ transforms.Compose([
...@@ -321,6 +320,7 @@ def validate(val_loader, model, criterion, args): ...@@ -321,6 +320,7 @@ def validate(val_loader, model, criterion, args):
if args.gpu is not None: if args.gpu is not None:
input = input.cuda(args.gpu, non_blocking=True) input = input.cuda(args.gpu, non_blocking=True)
target = target.cuda(args.gpu, non_blocking=True) target = target.cuda(args.gpu, non_blocking=True)
input = torch.cat([input[:, 2:3, :, :], input[:, 1:2, :, :], input[:, 0:1, :, :]], dim=1)
# compute output # compute output
output = model(input) output = model(input)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment