Unverified Commit b8d83b0d authored by Hang Zhang's avatar Hang Zhang Committed by GitHub
Browse files

transforms (#272)

parent f70fa97e
...@@ -93,7 +93,7 @@ py::array_t<float> apply_transform(int H, int W, int C, py::array_t<float> img, ...@@ -93,7 +93,7 @@ py::array_t<float> apply_transform(int H, int W, int C, py::array_t<float> img,
auto ctm_buf = ctm.request(); auto ctm_buf = ctm.request();
// printf("H: %d, W: %d, C: %d\n", H, W, C); // printf("H: %d, W: %d, C: %d\n", H, W, C);
py::array_t<float> result{img_buf.size}; py::array_t<float> result{(unsigned long)img_buf.size};
auto res_buf = result.request(); auto res_buf = result.request();
float *img_ptr = (float *)img_buf.ptr; float *img_ptr = (float *)img_buf.ptr;
......
...@@ -65,16 +65,16 @@ def get_transform(dataset, base_size=None, crop_size=224, rand_aug=False, etrans ...@@ -65,16 +65,16 @@ def get_transform(dataset, base_size=None, crop_size=224, rand_aug=False, etrans
normalize, normalize,
]) ])
elif dataset == 'cifar10': elif dataset == 'cifar10':
transform_train = transforms.Compose([ transform_train = Compose([
transforms.RandomCrop(32, padding=4), RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(), RandomHorizontalFlip(),
transforms.ToTensor(), ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)), (0.2023, 0.1994, 0.2010)),
]) ])
transform_val = transforms.Compose([ transform_val = Compose([
transforms.ToTensor(), ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)), (0.2023, 0.1994, 0.2010)),
]) ])
return transform_train, transform_val return transform_train, transform_val
......
...@@ -29,8 +29,10 @@ class LR_Scheduler(object): ...@@ -29,8 +29,10 @@ class LR_Scheduler(object):
iters_per_epoch: number of iterations per epoch iters_per_epoch: number of iterations per epoch
""" """
def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0, def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0,
lr_step=0, warmup_epochs=0): lr_step=0, warmup_epochs=0, quiet=False):
self.mode = mode self.mode = mode
self.quiet = quiet
if not quiet:
print('Using {} LR scheduler with warm-up epochs of {}!'.format(self.mode, warmup_epochs)) print('Using {} LR scheduler with warm-up epochs of {}!'.format(self.mode, warmup_epochs))
if mode == 'step': if mode == 'step':
assert lr_step assert lr_step
...@@ -57,6 +59,7 @@ class LR_Scheduler(object): ...@@ -57,6 +59,7 @@ class LR_Scheduler(object):
else: else:
raise NotImplemented raise NotImplemented
if epoch > self.epoch and (epoch == 0 or best_pred > 0.0): if epoch > self.epoch and (epoch == 0 or best_pred > 0.0):
if not self.quiet:
print('\n=>Epoch %i, learning rate = %.4f, \ print('\n=>Epoch %i, learning rate = %.4f, \
previous best = %.4f' % (epoch, lr, best_pred)) previous best = %.4f' % (epoch, lr, best_pred))
self.epoch = epoch self.epoch = epoch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment