Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Pytorch-Encoding
Commits
b8d83b0d
Unverified
Commit
b8d83b0d
authored
May 03, 2020
by
Hang Zhang
Committed by
GitHub
May 03, 2020
Browse files
transforms (#272)
parent
f70fa97e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
15 deletions
+18
-15
encoding/lib/cpu/operator.h
encoding/lib/cpu/operator.h
+1
-1
encoding/transforms/get_transform.py
encoding/transforms/get_transform.py
+10
-10
encoding/utils/lr_scheduler.py
encoding/utils/lr_scheduler.py
+7
-4
No files found.
encoding/lib/cpu/operator.h
View file @
b8d83b0d
...
...
@@ -93,7 +93,7 @@ py::array_t<float> apply_transform(int H, int W, int C, py::array_t<float> img,
auto
ctm_buf
=
ctm
.
request
();
// printf("H: %d, W: %d, C: %d\n", H, W, C);
py
::
array_t
<
float
>
result
{
img_buf
.
size
};
py
::
array_t
<
float
>
result
{
(
unsigned
long
)
img_buf
.
size
};
auto
res_buf
=
result
.
request
();
float
*
img_ptr
=
(
float
*
)
img_buf
.
ptr
;
...
...
encoding/transforms/get_transform.py
View file @
b8d83b0d
...
...
@@ -65,16 +65,16 @@ def get_transform(dataset, base_size=None, crop_size=224, rand_aug=False, etrans
normalize
,
])
elif
dataset
==
'cifar10'
:
transform_train
=
transforms
.
Compose
([
transforms
.
RandomCrop
(
32
,
padding
=
4
),
transforms
.
RandomHorizontalFlip
(),
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
transform_train
=
Compose
([
RandomCrop
(
32
,
padding
=
4
),
RandomHorizontalFlip
(),
ToTensor
(),
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
])
transform_val
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
transform_val
=
Compose
([
ToTensor
(),
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
])
return
transform_train
,
transform_val
...
...
encoding/utils/lr_scheduler.py
View file @
b8d83b0d
...
...
@@ -29,8 +29,10 @@ class LR_Scheduler(object):
iters_per_epoch: number of iterations per epoch
"""
def
__init__
(
self
,
mode
,
base_lr
,
num_epochs
,
iters_per_epoch
=
0
,
lr_step
=
0
,
warmup_epochs
=
0
):
lr_step
=
0
,
warmup_epochs
=
0
,
quiet
=
False
):
self
.
mode
=
mode
self
.
quiet
=
quiet
if
not
quiet
:
print
(
'Using {} LR scheduler with warm-up epochs of {}!'
.
format
(
self
.
mode
,
warmup_epochs
))
if
mode
==
'step'
:
assert
lr_step
...
...
@@ -57,6 +59,7 @@ class LR_Scheduler(object):
else
:
raise
NotImplemented
if
epoch
>
self
.
epoch
and
(
epoch
==
0
or
best_pred
>
0.0
):
if
not
self
.
quiet
:
print
(
'
\n
=>Epoch %i, learning rate = %.4f,
\
previous best = %.4f'
%
(
epoch
,
lr
,
best_pred
))
self
.
epoch
=
epoch
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment