Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e4f555f2
Unverified
Commit
e4f555f2
authored
Jun 20, 2022
by
ver217
Committed by
GitHub
Jun 20, 2022
Browse files
[optim] refactor fused sgd (#1134)
parent
d2690264
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
135 deletions
+31
-135
colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu
...ssalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu
+14
-57
colossalai/nn/optimizer/fused_sgd.py
colossalai/nn/optimizer/fused_sgd.py
+17
-78
No files found.
colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu
View file @
e4f555f2
...
...
@@ -28,10 +28,10 @@
* first run : necessary for proper momentum handling & init
* wd_after_momentum : apply weight decay _after_ momentum instead of before
**/
template
<
int
N
,
typename
T_grad
,
typename
T_weight
>
template
<
typename
T_grad
,
typename
T_weight
>
struct
SGDFunctor
{
__device__
__forceinline__
void
operator
()(
int
chunk_size
,
volatile
int
*
noop_gmem
,
TensorListMetadata
<
N
>
&
tl
,
int
chunk_size
,
volatile
int
*
noop_gmem
,
TensorListMetadata
<
3
>
&
tl
,
float
wd
,
float
momentum
,
float
dampening
,
float
lr
,
bool
nesterov
,
bool
first_run
,
bool
wd_after_momentum
,
float
scale
)
{
// Early exit if we don't need to do anything
...
...
@@ -50,12 +50,6 @@ struct SGDFunctor {
T_weight
*
mom_in
=
(
T_weight
*
)
tl
.
addresses
[
2
][
tensor_loc
];
mom_in
+=
chunk_idx
*
chunk_size
;
at
::
Half
*
model_weights_out
=
nullptr
;
if
(
N
==
4
)
{
model_weights_out
=
(
at
::
Half
*
)
tl
.
addresses
[
3
][
tensor_loc
];
model_weights_out
+=
chunk_idx
*
chunk_size
;
}
n
-=
chunk_idx
*
chunk_size
;
// Non-divergent exit condition for the __syncthreads
...
...
@@ -110,10 +104,6 @@ struct SGDFunctor {
// adjust the weight and write out
weight_in
[
i
]
+=
(
-
lr
*
incoming_grads
[
ii
]);
// if necessary, write out an fp16 copy of the weights
if
(
N
==
4
)
model_weights_out
[
i
]
=
static_cast
<
at
::
Half
>
(
weight_in
[
i
]);
// also write out the new momentum
if
(
momentum
!=
0.
f
)
mom_in
[
i
]
=
incoming_moms
[
ii
];
}
...
...
@@ -131,20 +121,14 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
auto
grad_type
=
tensor_lists
[
0
][
0
].
scalar_type
();
auto
weight_type
=
tensor_lists
[
1
][
0
].
scalar_type
();
if
(
num_tensors
==
4
)
for
(
int
i
=
0
;
i
<
tensor_lists
[
3
].
size
();
i
++
)
TORCH_CHECK
(
tensor_lists
[
3
][
i
].
scalar_type
()
==
at
::
ScalarType
::
Half
,
"Additional output tensors should always be fp16."
);
TORCH_CHECK
(
noop_flag
.
device
()
==
tensor_lists
[
0
][
0
].
device
(),
"expected noop flag to be on the same device as tensors"
);
// We have 3 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No
// 2. fp32, fp32, fp32, No
// 3. fp16, fp32, fp32, Yes
// 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
// grad_type, param_type, momentum_type
// 1. fp16, fp16, fp16
// 2. fp32, fp32, fp32
// 3. fp16, fp32, fp32
// It's easier to hardcode these possibilities than to use
// switches etc. to handle the cross-product of cases where
// we don't want the majority of them.
...
...
@@ -153,49 +137,22 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
if
(
grad_type
==
at
::
ScalarType
::
Half
&&
weight_type
==
at
::
ScalarType
::
Half
&&
num_tensors
==
3
)
{
multi_tensor_apply
<
3
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
SGDFunctor
<
3
,
at
::
Half
,
at
::
Half
>
(),
wd
,
momentum
,
SGDFunctor
<
at
::
Half
,
at
::
Half
>
(),
wd
,
momentum
,
dampening
,
lr
,
nesterov
,
first_run
,
wd_after_momentum
,
scale
);
}
// Case 2. fp16, fp32, fp32, No
// else if (grad_type == at::ScalarType::Half &&
// weight_type == at::ScalarType::Float &&
// num_tensors == 3) {
// multi_tensor_apply<3>(
// BLOCK_SIZE,
// chunk_size,
// noop_flag,
// tensor_lists,
// SGDFunctor<3, at::Half, float>(),
// wd,
// momentum,
// dampening,
// lr,
// nesterov,
// first_run,
// wd_after_momentum);
// }
// Case 2. fp32, fp32, fp32, No
// Case 2. fp32, fp32, fp32
else
if
(
grad_type
==
at
::
ScalarType
::
Float
&&
weight_type
==
at
::
ScalarType
::
Float
&&
num_tensors
==
3
)
{
multi_tensor_apply
<
3
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
SGDFunctor
<
3
,
float
,
float
>
(),
wd
,
momentum
,
dampening
,
lr
,
nesterov
,
first_run
,
wd_after_momentum
,
scale
);
SGDFunctor
<
float
,
float
>
(),
wd
,
momentum
,
dampening
,
lr
,
nesterov
,
first_run
,
wd_after_momentum
,
scale
);
}
// Case 3. fp16, fp32, fp32
, Yes
// Case 3. fp16, fp32, fp32
else
if
(
grad_type
==
at
::
ScalarType
::
Half
&&
weight_type
==
at
::
ScalarType
::
Float
&&
num_tensors
==
4
)
{
multi_tensor_apply
<
4
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
SGDFunctor
<
4
,
at
::
Half
,
float
>
(),
wd
,
momentum
,
dampening
,
lr
,
nesterov
,
first_run
,
wd_after_momentum
,
scale
);
}
// Case 4. fp32, fp32, fp32, Yes
else
if
(
grad_type
==
at
::
ScalarType
::
Float
&&
weight_type
==
at
::
ScalarType
::
Float
&&
num_tensors
==
4
)
{
multi_tensor_apply
<
4
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
SGDFunctor
<
4
,
float
,
float
>
(),
wd
,
momentum
,
weight_type
==
at
::
ScalarType
::
Float
&&
num_tensors
==
3
)
{
multi_tensor_apply
<
3
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
SGDFunctor
<
at
::
Half
,
float
>
(),
wd
,
momentum
,
dampening
,
lr
,
nesterov
,
first_run
,
wd_after_momentum
,
scale
);
}
else
{
...
...
colossalai/nn/optimizer/fused_sgd.py
View file @
e4f555f2
...
...
@@ -64,9 +64,7 @@ class FusedSGD(Optimizer):
dampening
=
0
,
weight_decay
=
0
,
nesterov
=
False
,
wd_after_momentum
=
False
,
materialize_master_grads
=
True
,
set_grad_none
=
False
):
wd_after_momentum
=
False
):
if
lr
is
not
required
and
lr
<
0.0
:
raise
ValueError
(
"Invalid learning rate: {}"
.
format
(
lr
))
if
momentum
<
0.0
:
...
...
@@ -80,10 +78,6 @@ class FusedSGD(Optimizer):
super
(
FusedSGD
,
self
).
__init__
(
params
,
defaults
)
self
.
wd_after_momentum
=
wd_after_momentum
self
.
materialize_master_grads
=
materialize_master_grads
self
.
most_recent_scale
=
1.0
self
.
scale_set_by_backward
=
False
self
.
set_grad_none
=
set_grad_none
if
multi_tensor_applier
.
available
:
import
colossal_C
...
...
@@ -100,14 +94,6 @@ class FusedSGD(Optimizer):
for
group
in
self
.
param_groups
:
group
.
setdefault
(
'nesterov'
,
False
)
def
zero_grad
(
self
):
if
self
.
set_grad_none
:
for
group
in
self
.
param_groups
:
for
p
in
group
[
'params'
]:
p
.
grad
=
None
else
:
super
(
FusedSGD
,
self
).
zero_grad
()
def
get_momentums
(
self
,
params
):
momentums
=
[]
first_run
=
True
...
...
@@ -136,74 +122,27 @@ class FusedSGD(Optimizer):
if
closure
is
not
None
:
loss
=
closure
()
explicit_master_params
=
(
hasattr
(
self
,
"_amp_stash"
)
and
hasattr
(
self
.
_amp_stash
,
"fp32_from_fp16_groups"
))
for
gid
,
group
in
enumerate
(
self
.
param_groups
):
for
group
in
self
.
param_groups
:
weight_decay
=
group
[
'weight_decay'
]
momentum
=
group
[
'momentum'
]
dampening
=
group
[
'dampening'
]
nesterov
=
group
[
'nesterov'
]
# For each group, there are 3 possible combinations we need to consider:
# grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
# 1. fp16, fp16, fp16, No
# 2. fp32, fp32, fp32, No
# 3. fp16, fp32, fp32, Yes
first_runs
=
[
True
,
True
]
# I think a bit of code divergence in exchange for naming clarity is worthwhile
if
explicit_master_params
:
stash
=
self
.
_amp_stash
fp32_params
=
[
p
for
p
in
stash
.
fp32_from_fp32_groups
[
gid
]
if
p
.
grad
is
not
None
]
fp32_grads
=
[
p
.
grad
for
p
in
stash
.
fp32_from_fp32_groups
[
gid
]
if
p
.
grad
is
not
None
]
fp32_momentums
,
first_runs
[
1
]
=
self
.
get_momentums
(
fp32_params
)
if
self
.
materialize_master_grads
:
fp16_model_params
=
[
p
for
i
,
p
in
enumerate
(
stash
.
fp16_groups
[
gid
])
if
stash
.
fp32_from_fp16_groups
[
gid
][
i
].
grad
is
not
None
]
fp32_from_fp16_grads
=
[
p
.
grad
for
p
in
stash
.
fp32_from_fp16_groups
[
gid
]
if
p
.
grad
is
not
None
]
fp32_from_fp16_params
=
[
p
for
p
in
stash
.
fp32_from_fp16_groups
[
gid
]
if
p
.
grad
is
not
None
]
fp32_from_fp16_momentums
,
first_runs
[
0
]
=
self
.
get_momentums
(
fp32_from_fp16_params
)
fp16_set
=
[
fp32_from_fp16_grads
,
fp32_from_fp16_params
,
fp32_from_fp16_momentums
,
fp16_model_params
]
else
:
fp16_model_params
=
[
p
for
p
in
stash
.
fp16_groups
[
gid
]
if
p
.
grad
is
not
None
]
fp16_model_grads
=
[
p
.
grad
for
p
in
stash
.
fp16_groups
[
gid
]
if
p
.
grad
is
not
None
]
fp32_from_fp16_params
=
[
p
for
i
,
p
in
enumerate
(
stash
.
fp32_from_fp16_groups
[
gid
])
if
stash
.
fp16_groups
[
gid
][
i
].
grad
is
not
None
]
fp32_from_fp16_momentums
,
first_runs
[
0
]
=
self
.
get_momentums
(
fp32_from_fp16_params
)
fp16_set
=
[
fp16_model_grads
,
fp32_from_fp16_params
,
fp32_from_fp16_momentums
,
fp16_model_params
]
launch_sets
=
[
fp16_set
,
[
fp32_grads
,
fp32_params
,
fp32_momentums
]]
else
:
fp16_params
=
[
p
for
p
in
group
[
'params'
]
if
(
p
.
dtype
==
torch
.
float16
and
p
.
grad
is
not
None
)]
fp16_grads
=
[
p
.
grad
for
p
in
group
[
'params'
]
if
(
p
.
dtype
==
torch
.
float16
and
p
.
grad
is
not
None
)]
fp16_momentums
,
first_runs
[
0
]
=
self
.
get_momentums
(
fp16_params
)
fp32_params
=
[
p
for
p
in
group
[
'params'
]
if
(
p
.
dtype
==
torch
.
float32
and
p
.
grad
is
not
None
)]
fp32_grads
=
[
p
.
grad
for
p
in
group
[
'params'
]
if
(
p
.
dtype
==
torch
.
float32
and
p
.
grad
is
not
None
)]
fp32_momentums
,
first_runs
[
1
]
=
self
.
get_momentums
(
fp32_params
)
launch_sets
=
[[
fp16_grads
,
fp16_params
,
fp16_momentums
],
[
fp32_grads
,
fp32_params
,
fp32_momentums
]]
for
s
,
(
launch_set
,
first_run
)
in
enumerate
(
zip
(
launch_sets
,
first_runs
)):
assert
len
(
launch_set
[
0
])
==
len
(
launch_set
[
1
])
assert
len
(
launch_set
[
0
])
==
len
(
launch_set
[
2
])
if
len
(
launch_set
[
0
])
>
0
:
multi_tensor_applier
(
self
.
multi_tensor_sgd
,
self
.
_dummy_overflow_buf
,
launch_set
,
weight_decay
,
momentum
,
dampening
,
group
[
'lr'
],
nesterov
,
first_run
,
self
.
wd_after_momentum
,
1.0
/
self
.
most_recent_scale
)
self
.
most_recent_scale
=
1.0
self
.
scale_set_by_backward
=
False
# grad_type, param_to_update_type, momentum_type
# 1. fp16, fp16, fp16
# 2. fp32, fp32, fp32
# 3. fp16, fp32, fp32
g_l
,
p_l
=
[],
[]
for
p
in
group
[
'params'
]:
if
p
.
grad
is
None
:
continue
if
p
.
grad
.
data
.
is_sparse
:
raise
RuntimeError
(
'FusedSGD does not support sparse gradients'
)
g_l
.
append
(
p
.
grad
)
p_l
.
append
(
p
)
m_l
,
first_run
=
self
.
get_momentums
(
p_l
)
multi_tensor_applier
(
self
.
multi_tensor_sgd
,
self
.
_dummy_overflow_buf
,
[
g_l
,
p_l
,
m_l
],
weight_decay
,
momentum
,
dampening
,
group
[
'lr'
],
nesterov
,
first_run
,
self
.
wd_after_momentum
,
1.0
)
return
loss
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment