OpenDAS / ColossalAI · Commits

Commit cd13b638 (unverified)
Authored Apr 01, 2022 by アマデウス, committed by GitHub on Apr 01, 2022
Parent: acae68eb

[model checkpoint] reworked unified layers for ease of save/load states (#593)

Showing 6 changed files with 85 additions and 120 deletions.
colossalai/nn/layer/base_layer.py                        +8   -0
colossalai/nn/layer/colossalai_layer/_utils.py           +19  -0
colossalai/nn/layer/colossalai_layer/dropout.py          +9   -8
colossalai/nn/layer/colossalai_layer/embedding.py        +19  -43
colossalai/nn/layer/colossalai_layer/linear.py           +20  -41
colossalai/nn/layer/colossalai_layer/normalization.py    +10  -28
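The common thread in these diffs is the new `ColossalaiModule` wrapper added in `_utils.py` below: instead of stashing the real layer in an attribute such as `self.layer`, `self.embed`, `self.norm`, or `self.drop`, each unified layer now copies the wrapped module's `__dict__` into itself, so parameters register under the same names as in the wrapped module and checkpoint keys lose their nested prefix. A minimal self-contained sketch of the pattern (`WrapperModule` is a made-up name for illustration; the real class appears in the `_utils.py` diff):

```python
import torch.nn as nn

class WrapperModule(nn.Module):
    """Made-up stand-in for ColossalaiModule, reduced to the key trick."""

    def __init__(self, module: nn.Module):
        super().__init__()
        # Adopt the wrapped module's attribute dicts (_parameters, _buffers,
        # ...), so its parameters appear directly on the wrapper.
        self.__dict__ = module.__dict__.copy()
        self._forward_func = module.forward

    def forward(self, *args):
        return self._forward_func(*args)

old_style = nn.Module()                 # pre-commit pattern: layer is a child
old_style.layer = nn.Linear(4, 4)
print(list(old_style.state_dict()))     # ['layer.weight', 'layer.bias']

new_style = WrapperModule(nn.Linear(4, 4))
print(list(new_style.state_dict()))     # ['weight', 'bias'] - flat keys
```

The per-file diffs below apply exactly this change to Dropout, Embedding, PatchEmbedding, Linear, Classifier, and LayerNorm.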
colossalai/nn/layer/base_layer.py

```diff
@@ -25,3 +25,11 @@ class ParallelLayer(nn.Module):
             ParallelMode.PIPELINE)
         self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
             ParallelMode.PIPELINE)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                              error_msgs):
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                                      error_msgs)
+        if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
+            missing_keys.clear()
+            unexpected_keys.clear()
```
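PyTorch's `load_state_dict` calls `_load_from_state_dict` once per submodule and accumulates mismatched key names into shared `missing_keys`/`unexpected_keys` lists, so clearing those lists on every tensor-parallel rank other than 0 keeps each mismatch from being reported once per rank. A standalone sketch of the mechanism (the rank check is faked with a flag, since this snippet runs outside a distributed context):

```python
import torch
import torch.nn as nn

IS_REPORTING_RANK = False   # stand-in for gpc.get_local_rank(ParallelMode.TENSOR) == 0

class QuietLinear(nn.Linear):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
        if not IS_REPORTING_RANK:
            # Drop this rank's complaints; rank 0 alone reports mismatches.
            missing_keys.clear()
            unexpected_keys.clear()

layer = QuietLinear(4, 4)
result = layer.load_state_dict({'extra': torch.zeros(1)}, strict=False)
print(result.missing_keys, result.unexpected_keys)   # [] [] - suppressed
```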
colossalai/nn/layer/colossalai_layer/_utils.py

```diff
+import torch.nn as nn
 from torch import Tensor

 from ..parallel_2d._operation import split_tensor_2d
@@ -17,3 +18,21 @@ def partition_batch(input_) -> Tensor:
         return _parallel_split_batch[tensor_parallel_mode](input_)
     else:
         return input_
+
+
+class ColossalaiModule(nn.Module):
+
+    def __init__(self, module: nn.Module, **kwargs):
+        super().__init__()
+        # copy values
+        self.__dict__ = module.__dict__.copy()
+        # copy methods
+        for name, attr in module.__class__.__dict__.items():
+            if name not in ['__init__', 'forward'] and callable(attr):
+                setattr(self, name, getattr(module, name))
+        self._forward_func = module.forward
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+    def forward(self, *args):
+        return self._forward_func(*args)
```
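Hypothetical usage of the new wrapper (the import path is taken from the file above and requires ColossalAI at this commit): keyword arguments passed to `ColossalaiModule` become plain attributes on the wrapper, which is how the reworked `Dropout` records its `tensor_parallel` mode, and the wrapped module's parameters surface without any prefix:

```python
import torch
import torch.nn as nn
from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule

wrapped = ColossalaiModule(nn.Linear(8, 8), tensor_parallel='1d')
print(wrapped.tensor_parallel)            # '1d' - kwarg stored as attribute
print(list(wrapped.state_dict()))         # ['weight', 'bias'] - no nesting
out = wrapped(torch.randn(2, 8))          # delegates to the inner Linear
```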
colossalai/nn/layer/colossalai_layer/dropout.py

```diff
@@ -3,9 +3,10 @@ from colossalai.context import ParallelMode, seed
 from ..parallel_1d import *
 from ..utils import get_tensor_parallel_mode
+from ._utils import ColossalaiModule


-class Dropout(nn.Module):
+class Dropout(ColossalaiModule):
     """Dropout layer of colossalai.

     Args:
@@ -13,16 +14,16 @@ class Dropout(nn.Module):
         inplace (bool, optional): whether to do dropout in-place, default to be False.
     """

     def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
-        super().__init__()
-        self.tensor_parallel = get_tensor_parallel_mode()
-        if self.tensor_parallel == '1d':
-            self.drop = Dropout1D(p, inplace)
+        tensor_parallel = get_tensor_parallel_mode()
+        if tensor_parallel == "1d":
+            drop = Dropout1D(p, inplace)
         else:
-            self.drop = nn.Dropout(p, inplace)
+            drop = nn.Dropout(p, inplace)
+        super().__init__(drop, tensor_parallel=tensor_parallel)

     def forward(self, *args):
         if self.tensor_parallel in [None, '1d']:
-            return self.drop(*args)
+            return self._forward_func(*args)
         else:
             with seed(ParallelMode.TENSOR):
-                return self.drop(*args)
+                return self._forward_func(*args)
```
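The `with seed(ParallelMode.TENSOR)` block in `forward` switches to an RNG stream shared by all tensor-parallel ranks, so replicated activations receive identical dropout masks. A simplified stand-in for `colossalai.context.seed` (an assumption about its behavior, not the actual implementation), keeping one persistent RNG state per mode:

```python
import torch
from contextlib import contextmanager

_mode_states = {}   # one persistent RNG state per parallel mode

@contextmanager
def seed(mode: str, init_seed: int = 42):
    outer = torch.get_rng_state()                   # save the ambient stream
    if mode not in _mode_states:
        torch.manual_seed(init_seed)                # every rank starts identically
    else:
        torch.set_rng_state(_mode_states[mode])     # resume the mode's stream
    try:
        yield
    finally:
        _mode_states[mode] = torch.get_rng_state()  # persist the stream's progress
        torch.set_rng_state(outer)                  # restore the ambient stream

with seed('TENSOR'):
    a = torch.rand(3)   # first dropout mask
with seed('TENSOR'):
    b = torch.rand(3)   # stream continues: b != a, but every rank running
                        # this code with the same init_seed sees the same a, b
```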
colossalai/nn/layer/colossalai_layer/embedding.py

```diff
@@ -5,14 +5,16 @@ from colossalai.utils import get_current_device
 from torch import dtype, nn

 from ... import init as init
-from ..parallel_1d import *
-from ..parallel_2d import *
-from ..parallel_2p5d import *
-from ..parallel_3d import *
+from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D
+from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D
+from ..parallel_2p5d import Embedding2p5D, PatchEmbedding2p5D, VocabParallelEmbedding2p5D
+from ..parallel_3d import Embedding3D, PatchEmbedding3D, VocabParallelEmbedding3D
 from ..utils import get_tensor_parallel_mode
-from ..vanilla import *
+from ..vanilla import VanillaPatchEmbedding
+from ._utils import ColossalaiModule

 _parallel_embedding = {
     '1d': Embedding1D,
     '2d': Embedding2D,
     '2.5d': Embedding2p5D,
     '3d': Embedding3D,
@@ -27,14 +29,14 @@ _vocab_parallel_embedding = {
 _parallel_patchembedding = {
     None: VanillaPatchEmbedding,
-    '1d': VanillaPatchEmbedding,
+    '1d': PatchEmbedding1D,
     '2d': PatchEmbedding2D,
     '2.5d': PatchEmbedding2p5D,
     '3d': PatchEmbedding3D
 }


-class Embedding(nn.Module):
+class Embedding(ColossalaiModule):
     r"""Embedding for colossalai.

     Args:
@@ -73,14 +75,13 @@ class Embedding(nn.Module):
                  vocab_parallel_limit: int = 2048,
                  *args,
                  **kwargs) -> None:
-        super().__init__()
         tensor_parallel = get_tensor_parallel_mode()
-        if tensor_parallel is None or (tensor_parallel == '1d' and num_embeddings <= vocab_parallel_limit):
-            self.embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args,
-                                      **kwargs).to(dtype).to(get_current_device())
-            weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
+        if tensor_parallel is None:
+            embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args,
+                                 **kwargs).to(dtype).to(get_current_device())
+            weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
         elif num_embeddings <= vocab_parallel_limit:
-            self.embed = _parallel_embedding[tensor_parallel](
+            embed = _parallel_embedding[tensor_parallel](
                 num_embeddings,
                 embedding_dim,
                 padding_idx=padding_idx,
@@ -90,7 +91,7 @@ class Embedding(nn.Module):
                 **kwargs,
             )
         else:
-            self.embed = _vocab_parallel_embedding[tensor_parallel](
+            embed = _vocab_parallel_embedding[tensor_parallel](
                 num_embeddings,
                 embedding_dim,
                 padding_idx=padding_idx,
@@ -99,16 +100,10 @@ class Embedding(nn.Module):
                 *args,
                 **kwargs,
             )
-
-    @property
-    def weight(self):
-        return self.embed.weight
-
-    def forward(self, *args):
-        return self.embed(*args)
+        super().__init__(embed)


-class PatchEmbedding(nn.Module):
+class PatchEmbedding(ColossalaiModule):
     """2D Image to Patch Embedding.

     Args:
@@ -141,9 +136,8 @@ class PatchEmbedding(nn.Module):
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                  position_embed_initializer: Callable = init.zeros_()) -> None:
-        super().__init__()
         tensor_parallel = get_tensor_parallel_mode()
-        self.embed = _parallel_patchembedding[tensor_parallel](
+        embed = _parallel_patchembedding[tensor_parallel](
             img_size,
             patch_size,
             in_chans,
@@ -154,22 +148,4 @@ class PatchEmbedding(nn.Module):
             bias_initializer=bias_initializer,
             position_embed_initializer=position_embed_initializer,
         )
-
-    @property
-    def weight(self):
-        return self.embed.weight
-
-    @property
-    def bias(self):
-        return self.embed.bias
-
-    @property
-    def pos_embed(self):
-        return self.embed.pos_embed
-
-    @property
-    def cls_token(self):
-        return self.embed.cls_token
-
-    def forward(self, *args):
-        return self.embed(*args)
+        super().__init__(embed)
```
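`Embedding.__init__` thus selects one of three implementations based on the parallel mode and vocabulary size. An illustrative helper (the function is made up; names in quotes refer to the dicts in the diff above) spelling out the three-way dispatch:

```python
from typing import Optional

def dispatch(tensor_parallel: Optional[str], num_embeddings: int,
             vocab_parallel_limit: int = 2048) -> str:
    if tensor_parallel is None:
        return 'nn.Embedding'               # plain PyTorch embedding
    if num_embeddings <= vocab_parallel_limit:
        return '_parallel_embedding'        # e.g. Embedding2D for '2d'
    return '_vocab_parallel_embedding'      # e.g. VocabParallelEmbedding2D

assert dispatch(None, 50304) == 'nn.Embedding'
assert dispatch('2d', 1000) == '_parallel_embedding'
assert dispatch('2d', 50304) == '_vocab_parallel_embedding'
```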
colossalai/nn/layer/colossalai_layer/linear.py

```diff
@@ -12,6 +12,7 @@ from ..parallel_2p5d import *
 from ..parallel_3d import *
 from ..utils import get_tensor_parallel_mode
 from ..vanilla import *
+from ._utils import ColossalaiModule

 _parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
@@ -31,7 +32,7 @@ _vocab_parallel_classifier = {
 }


-class Linear(nn.Module):
+class Linear(ColossalaiModule):
     """Linear layer of colossalai.

     Args:
@@ -71,41 +72,30 @@ class Linear(nn.Module):
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                  **kwargs) -> None:
-        super().__init__()
         tensor_parallel = get_tensor_parallel_mode()
         if tensor_parallel is None:
-            self.layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device())
-            weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
-            if self.layer.bias is not None:
-                bias_initializer(self.layer.bias, fan_in=in_features)
+            layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device())
+            weight_initializer(layer.weight, fan_in=in_features, fan_out=out_features)
+            if layer.bias is not None:
+                bias_initializer(layer.bias, fan_in=in_features)
         else:
             linear_cls = _parallel_linear[tensor_parallel]
             gather_output = kwargs.pop('gather_output', None)
             if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys():    # gather_out arg is available
                 kwargs['gather_output'] = gather_output
-            self.layer = linear_cls(
+            layer = linear_cls(
                 in_features,
                 out_features,
                 bias=bias,
                 dtype=dtype,
                 weight_initializer=weight_initializer,
                 bias_initializer=bias_initializer,
                 **kwargs,
             )
-
-    @property
-    def weight(self):
-        return self.layer.weight
-
-    @property
-    def bias(self):
-        return self.layer.bias
-
-    def forward(self, *args):
-        return self.layer(*args)
+        super().__init__(layer)


-class Classifier(nn.Module):
+class Classifier(ColossalaiModule):
     """Classifier layer of colossalai.

     Args:
@@ -132,10 +122,9 @@ class Classifier(nn.Module):
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                  vocab_parallel_limit: int = 2048) -> None:
-        super().__init__()
         tensor_parallel = get_tensor_parallel_mode()
         if num_classes <= vocab_parallel_limit or tensor_parallel is None:
-            self.layer = _parallel_classifier[tensor_parallel](
+            layer = _parallel_classifier[tensor_parallel](
                 in_features,
                 num_classes,
                 weight=weight,
@@ -145,7 +134,7 @@ class Classifier(nn.Module):
                 bias_initializer=bias_initializer,
             )
         else:
-            self.layer = _vocab_parallel_classifier[tensor_parallel](
+            layer = _vocab_parallel_classifier[tensor_parallel](
                 in_features,
                 num_classes,
                 weight=weight,
@@ -154,14 +143,4 @@ class Classifier(nn.Module):
                 weight_initializer=weight_initializer,
                 bias_initializer=bias_initializer,
             )
-
-    @property
-    def weight(self):
-        return self.layer.weight
-
-    @property
-    def bias(self):
-        return self.layer.bias
-
-    def forward(self, *args):
-        return self.layer(*args)
+        super().__init__(layer)
```
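The `gather_output` handling in `Linear.__init__` is a feature-detection idiom: the kwarg is popped unconditionally, then re-added only when the selected class accepts it, letting parallel linear variants with different signatures share one front end. A standalone demonstration with stand-in classes:

```python
import inspect

class WithGather:
    def __init__(self, dim, gather_output=False): ...

class WithoutGather:
    def __init__(self, dim): ...

def build(cls, dim, **kwargs):
    gather_output = kwargs.pop('gather_output', None)
    if 'gather_output' in inspect.signature(cls.__init__).parameters:
        kwargs['gather_output'] = gather_output    # class supports the kwarg
    return cls(dim, **kwargs)

build(WithGather, 8, gather_output=True)      # kwarg forwarded
build(WithoutGather, 8, gather_output=True)   # kwarg silently dropped
```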
colossalai/nn/layer/colossalai_layer/normalization.py

```diff
 from colossalai.utils import get_current_device
 from torch import nn

-from colossalai import kernel
 from ... import init as init
-from ..parallel_1d import *
-from ..parallel_2d import *
-from ..parallel_2p5d import *
-from ..parallel_3d import *
+from ..parallel_1d import LayerNorm1D
+from ..parallel_2d import LayerNorm2D
+from ..parallel_2p5d import LayerNorm2p5D
+from ..parallel_3d import LayerNorm3D
 from ..utils import get_tensor_parallel_mode
-from ..vanilla import *
+from ._utils import ColossalaiModule

-_parallel_layernorm = {'1d': kernel.LayerNorm, '2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
+_parallel_layernorm = {'1d': LayerNorm1D, '2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}


-class LayerNorm(nn.Module):
+class LayerNorm(ColossalaiModule):
     r"""Layer Normalization for colossalai.

     Args:
@@ -31,20 +24,9 @@ class LayerNorm(nn.Module):
     """

     def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
-        super().__init__()
         tensor_parallel = get_tensor_parallel_mode()
         if tensor_parallel is None:
-            self.norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
+            norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
         else:
-            self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
-
-    @property
-    def weight(self):
-        return self.norm.weight
-
-    @property
-    def bias(self):
-        return self.norm.bias
-
-    def forward(self, *args):
-        return self.norm(*args)
+            norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
+        super().__init__(norm)
```
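Finally, a hedged end-to-end sketch of what the rework buys (assumes ColossalAI at this commit, and that `colossalai.launch_from_torch` accepts an empty config when run under `torchrun --nproc_per_node=1`): with no tensor parallelism configured, the unified `Linear` now yields the same checkpoint keys as `torch.nn.Linear`, so plain PyTorch checkpoints load directly.

```python
import colossalai
import torch.nn as nn
from colossalai.nn import Linear

colossalai.launch_from_torch(config={})   # assumed single-process launch

col_linear = Linear(16, 16)               # no tensor parallel mode set
torch_linear = nn.Linear(16, 16)

# Before this commit the wrapper's keys were 'layer.weight'/'layer.bias';
# now they match nn.Linear, so state dicts round-trip both ways.
assert set(col_linear.state_dict()) == set(torch_linear.state_dict())
col_linear.load_state_dict(torch_linear.state_dict())
```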