OpenDAS / apex · Commits

Commit 199fa834 (unverified)
be more flexible (#1299)

Authored Feb 23, 2022 by Masaki Kozuki; committed by GitHub on Feb 23, 2022.
Parent: 069ff336
Showing 1 changed file with 5 additions and 2 deletions:

apex/transformer/pipeline_parallel/schedules/common.py (+5, -2)
apex/transformer/pipeline_parallel/schedules/common.py (view file @ 199fa834)

...
@@ -3,6 +3,8 @@ from typing import Any, Callable, Dict, List, Tuple, Union, Optional, Sequence
 import torch
 from torch.autograd.variable import Variable
+from apex.contrib.layer_norm.layer_norm import FastLayerNorm
+from apex.normalization.fused_layer_norm import FusedLayerNorm
 from apex.transformer import parallel_state
 from apex.transformer.enums import ModelType
 from apex.transformer.pipeline_parallel.utils import get_num_microbatches
...
@@ -119,18 +121,19 @@ def _calc_number_of_params(model: List[torch.nn.Module]) -> int:
 def _get_params_for_weight_decay_optimization(
     model: Union[torch.nn.Module, List[torch.nn.Module]],
+    *,
+    no_weight_decay_modules=(FastLayerNorm, FusedLayerNorm),
 ) -> Dict[str, torch.nn.Parameter]:
     """Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and biases will have no weight decay but the rest will.
     """
     modules = listify_model(model)
-    from apex.normalization.fused_layer_norm import FusedLayerNorm  # NOQA
     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
     for module in modules:
         for module_ in module.modules():
-            if isinstance(module_, FusedLayerNorm):
+            if isinstance(module_, no_weight_decay_modules):
                 no_weight_decay_params['params'].extend(
                     [p for p in list(module_._parameters.values())
                      if p is not None])
...
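The practical effect of the change: the set of module types routed into the zero-weight-decay group is no longer hard-coded to FusedLayerNorm but is now a keyword-only argument defaulting to (FastLayerNorm, FusedLayerNorm). Below is a minimal usage sketch. The toy Sequential model, the extra torch.nn.LayerNorm entry, and the optimizer wiring are illustrative assumptions, not part of the commit, and the sketch assumes the function's elided tail returns the two param-group dicts.

# Usage sketch for the new keyword-only argument (hypothetical caller code).
import torch
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
from apex.normalization.fused_layer_norm import FusedLayerNorm
from apex.transformer.pipeline_parallel.schedules.common import (
    _get_params_for_weight_decay_optimization,
)

# Toy model for illustration: a Linear plus a vanilla torch.nn.LayerNorm,
# which is neither FastLayerNorm nor FusedLayerNorm, so it would previously
# have received weight decay.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))

# Callers can now widen the exclusion set, e.g. to also exempt vanilla
# LayerNorm parameters from weight decay; isinstance accepts the tuple.
param_groups = _get_params_for_weight_decay_optimization(
    model,
    no_weight_decay_modules=(FastLayerNorm, FusedLayerNorm, torch.nn.LayerNorm),
)

# Assumption: the function returns the two group dicts, which can be fed
# directly to an optimizer as parameter groups.
optimizer = torch.optim.AdamW(param_groups, lr=1e-4, weight_decay=0.01)

Making the argument keyword-only (the bare *) keeps existing positional call sites working unchanged, and hoisting the FusedLayerNorm import to module scope removes the function-local # NOQA import.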