Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
3b32c401
Commit
3b32c401
authored
Apr 26, 2019
by
Michael Carilli
Browse files
Fixed bounds checking
parent
2c63ba91
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
20 deletions
+21
-20
apex/optimizers/fused_sgd.py
apex/optimizers/fused_sgd.py
+4
-3
csrc/multi_tensor_sgd_kernel.cu
csrc/multi_tensor_sgd_kernel.cu
+17
-17
No files found.
apex/optimizers/fused_sgd.py
View file @
3b32c401
...
...
@@ -63,7 +63,7 @@ class FusedSGD(Optimizer):
weight_decay
=
weight_decay
,
nesterov
=
nesterov
)
if
nesterov
and
(
momentum
<=
0
or
dampening
!=
0
):
raise
ValueError
(
"Nesterov momentum requires a momentum and zero dampening"
)
super
(
SGD
,
self
).
__init__
(
params
,
defaults
)
super
(
Fused
SGD
,
self
).
__init__
(
params
,
defaults
)
self
.
wd_after_momentum
=
wd_after_momentum
...
...
@@ -80,8 +80,9 @@ class FusedSGD(Optimizer):
for
group
in
self
.
param_groups
:
group
.
setdefault
(
'nesterov'
,
False
)
def
get_momentums
(
params
):
def
get_momentums
(
self
,
params
):
momentums
=
[]
first_run
=
True
for
p
in
params
:
param_state
=
self
.
state
[
p
]
# torch.optim.SGD initializes momentum in the main loop, we have
...
...
@@ -153,7 +154,7 @@ class FusedSGD(Optimizer):
launch_sets
=
[[
fp16_grads
,
fp16_params
,
fp16_momentums
],
[
fp32_grads
,
fp32_params
,
fp32_momentums
]]
for
launch_set
,
first_run
in
zip
(
launch_sets
,
first_runs
):
for
s
,
(
launch_set
,
first_run
)
in
enumerate
(
zip
(
launch_sets
,
first_runs
)
)
:
assert
len
(
launch_set
[
0
])
==
len
(
launch_set
[
1
])
assert
len
(
launch_set
[
0
])
==
len
(
launch_set
[
2
])
if
len
(
launch_set
[
0
])
>
0
:
...
...
csrc/multi_tensor_sgd_kernel.cu
View file @
3b32c401
...
...
@@ -57,7 +57,8 @@ struct SGDFunctor
mom_in
+=
chunk_idx
*
chunk_size
;
at
::
Half
*
model_weights_out
=
nullptr
;
if
(
N
==
4
)
{
if
(
N
==
4
)
{
model_weights_out
=
(
at
::
Half
*
)
tl
.
addresses
[
3
][
tensor_loc
];
model_weights_out
+=
chunk_idx
*
chunk_size
;
}
...
...
@@ -80,9 +81,11 @@ struct SGDFunctor
incoming_moms
[
ii
]
=
0
;
int
i
=
i_start
+
threadIdx
.
x
+
ii
*
blockDim
.
x
;
if
(
i
<
n
&&
i
<
chunk_size
)
{
incoming_grads
[
ii
]
=
static_cast
<
float
>
(
grad_in
[
i
]);
incoming_weights
[
ii
]
=
static_cast
<
float
>
(
weight_in
[
i
]);
incoming_moms
[
ii
]
=
static_cast
<
float
>
(
mom_in
[
i
]);
}
}
// note for clarification to future michael:
...
...
@@ -94,43 +97,40 @@ struct SGDFunctor
for
(
int
ii
=
0
;
ii
<
ILP
;
ii
++
)
{
int
i
=
i_start
+
threadIdx
.
x
+
ii
*
blockDim
.
x
;
if
(
i
<
n
&&
i
<
chunk_size
)
{
if
(
i
<
n
&&
i
<
chunk_size
)
{
// apply weight decay before momentum if necessary
if
(
wd
!=
0.
f
&&
!
wd_after_momentum
)
{
if
(
wd
!=
0.
f
&&
!
wd_after_momentum
)
incoming_grads
[
ii
]
+=
wd
*
incoming_weights
[
ii
];
}
if
(
momentum
!=
0.
f
)
{
if
(
!
first_run
)
{
if
(
momentum
!=
0.
f
)
{
if
(
!
first_run
)
incoming_moms
[
ii
]
=
incoming_moms
[
ii
]
*
momentum
+
(
1.
f
-
dampening
)
*
incoming_grads
[
ii
];
}
else
{
else
// initialize momentume to current incoming grads
incoming_moms
[
ii
]
=
incoming_grads
[
ii
];
}
if
(
nesterov
)
{
if
(
nesterov
)
incoming_grads
[
ii
]
+=
momentum
*
incoming_moms
[
ii
];
}
else
{
else
incoming_grads
[
ii
]
=
incoming_moms
[
ii
];
}
}
// Apply WD after momentum if desired
if
(
wd
!=
0.
f
&&
wd_after_momentum
)
{
if
(
wd
!=
0.
f
&&
wd_after_momentum
)
incoming_grads
[
ii
]
+=
wd
*
incoming_weights
[
ii
];
}
// adjust the weight and write out
weight_in
[
i
]
+=
(
-
lr
*
incoming_grads
[
ii
]);
// if necessary, write out an fp16 copy of the weights
if
(
N
==
4
)
{
if
(
N
==
4
)
model_weights_out
[
i
]
=
static_cast
<
at
::
Half
>
(
weight_in
[
i
]);
}
// also write out the new momentum
if
(
momentum
!=
0.
f
)
{
if
(
momentum
!=
0.
f
)
mom_in
[
i
]
=
incoming_moms
[
ii
];
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment