OpenDAS / ColossalAI / Commits / c50bfb80

Unverified commit c50bfb80, authored Apr 01, 2022 by アマデウス, committed by GitHub on Apr 01, 2022.

[model checkpoint] updated saving/loading for 1d layers (#594)

Parent: 7636d518
Showing 2 changed files with 371 additions and 30 deletions (+371 -30)
colossalai/nn/layer/parallel_1d/__init__.py  (+3 -3)
colossalai/nn/layer/parallel_1d/layers.py    (+368 -27)
colossalai/nn/layer/parallel_1d/__init__.py  (view file @ c50bfb80)

-from .layers import (Classifier1D, Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row,
-                     VocabParallelClassifier1D, VocabParallelEmbedding1D)
+from .layers import (Classifier1D, Dropout1D, Embedding1D, LayerNorm1D, Linear1D, Linear1D_Col, Linear1D_Row,
+                     PatchEmbedding1D, VocabParallelClassifier1D, VocabParallelEmbedding1D)

 __all__ = [
     'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D',
-    'VocabParallelClassifier1D', 'VocabParallelEmbedding1D'
+    'VocabParallelClassifier1D', 'VocabParallelEmbedding1D', 'LayerNorm1D', 'PatchEmbedding1D'
 ]
colossalai/nn/layer/parallel_1d/layers.py  (view file @ c50bfb80)
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-

 import math
+from collections import OrderedDict
 from typing import Callable, Tuple

 import torch
...
@@ -10,20 +11,25 @@ from colossalai.communication import broadcast
...
@@ -10,20 +11,25 @@ from colossalai.communication import broadcast
from
colossalai.context
import
ParallelMode
,
seed
from
colossalai.context
import
ParallelMode
,
seed
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.global_variables
import
tensor_parallel_env
as
env
from
colossalai.global_variables
import
tensor_parallel_env
as
env
from
colossalai.kernel
import
LayerNorm
from
colossalai.nn
import
init
as
init
from
colossalai.nn
import
init
as
init
from
colossalai.registry
import
LAYERS
from
colossalai.registry
import
LAYERS
from
colossalai.utils.checkpointing
import
(
broadcast_state_dict
,
gather_tensor_parallel_state_dict
,
partition_tensor_parallel_state_dict
)
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
torch
import
Tensor
from
torch
import
Tensor
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
..vanilla
import
VanillaPatchEmbedding
from
..base_layer
import
ParallelLayer
from
..base_layer
import
ParallelLayer
from
..colossalai_layer._utils
import
ColossalaiModule
from
..utils
import
divide
,
set_tensor_parallel_attribute_by_partition
from
..utils
import
divide
,
set_tensor_parallel_attribute_by_partition
from
._utils
import
(
gather_forward_split_backward
,
get_parallel_input
,
reduce_grad
,
reduce_input
,
set_parallel_input
,
from
._utils
import
(
gather_forward_split_backward
,
get_parallel_input
,
reduce_grad
,
reduce_input
,
set_parallel_input
,
split_forward_gather_backward
)
split_forward_gather_backward
)
@
LAYERS
.
register_module
@
LAYERS
.
register_module
class
Linear1D
(
torch
.
nn
.
Module
):
class
Linear1D
(
Colossalai
Module
):
r
"""Linear layer for 1D parallelism.
r
"""Linear layer for 1D parallelism.
Args:
Args:
@@ -52,37 +58,69 @@ class Linear1D(torch.nn.Module):
                  skip_bias_add: bool = False,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
-        super().__init__()
         parallel_input = get_parallel_input()
         if not parallel_input:
-            self.layer = Linear1D_Col(in_features,
-                                      out_features,
-                                      bias=bias,
-                                      dtype=dtype,
-                                      gather_output=gather_output,
-                                      skip_bias_add=skip_bias_add,
-                                      weight_initializer=weight_initializer,
-                                      bias_initializer=bias_initializer)
+            layer = Linear1D_Col(in_features,
+                                 out_features,
+                                 bias=bias,
+                                 dtype=dtype,
+                                 gather_output=gather_output,
+                                 skip_bias_add=skip_bias_add,
+                                 weight_initializer=weight_initializer,
+                                 bias_initializer=bias_initializer)
         else:
-            self.layer = Linear1D_Row(in_features,
-                                      out_features,
-                                      bias=bias,
-                                      dtype=dtype,
-                                      parallel_input=parallel_input,
-                                      skip_bias_add=skip_bias_add,
-                                      weight_initializer=weight_initializer,
-                                      bias_initializer=bias_initializer)
-
-    @property
-    def weight(self):
-        return self.layer.weight
-
-    @property
-    def bias(self):
-        return self.layer.bias
-
-    def forward(self, input_: Tensor) -> Tensor:
-        return self.layer(input_)
+            layer = Linear1D_Row(in_features,
+                                 out_features,
+                                 bias=bias,
+                                 dtype=dtype,
+                                 parallel_input=parallel_input,
+                                 skip_bias_add=skip_bias_add,
+                                 weight_initializer=weight_initializer,
+                                 bias_initializer=bias_initializer)
+        super().__init__(layer)
+
+
+@LAYERS.register_module
+class LayerNorm1D(ColossalaiModule):
+    r"""
+    Layer Normalization for colossalai
+
+    :param normalized_shape: input shape from an expected input of size.
+        :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+        \times \ldots \times \text{normalized_shape}[-1]]`
+        If a single integer is used, it is treated as a singleton list, and this module will
+        normalize over the last dimension which is expected to be of that specific size.
+    :type normalized_shape: int
+    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
+    :type eps: float, optional
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    """
+
+    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None):
+        norm = LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
+        super().__init__(norm)
+
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+            # bias
+            bias = state_dict.pop(bias_key, None)
+            if bias is not None:
+                local_state[bias_key] = bias
+
+        local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            super()._save_to_state_dict(destination, prefix, keep_vars)


 @LAYERS.register_module
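Editor's note: for the replicated layers in this hunk (Linear1D, now a thin ColossalaiModule wrapper, and the new LayerNorm1D) the hooks keep exactly one full copy of each parameter in the checkpoint: only tensor-parallel rank 0 consumes the entries on load, and broadcast_state_dict then distributes them to the rest of the PARALLEL_1D group. Below is a minimal single-process sketch of that gating; the collect_local_state helper and its tensor_rank argument are illustrative, not part of ColossalAI.

from collections import OrderedDict

import torch


def collect_local_state(state_dict, tensor_rank):
    # Toy version of the gating in LayerNorm1D._load_from_state_dict: only the
    # first tensor-parallel rank consumes the checkpoint entries; the real code
    # then broadcasts the resulting dict to the other PARALLEL_1D ranks.
    local_state = OrderedDict()
    if tensor_rank == 0:
        for key in ('weight', 'bias'):
            value = state_dict.pop(key, None)
            if value is not None:
                local_state[key] = value
    return local_state


checkpoint = {'weight': torch.ones(8), 'bias': torch.zeros(8)}
print(collect_local_state(dict(checkpoint), tensor_rank=0))  # both tensors on rank 0
print(collect_local_state(dict(checkpoint), tensor_rank=1))  # empty elsewhere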
@@ -153,6 +191,55 @@ class Classifier1D(ParallelLayer):
             num_partition = gpc.get_world_size(ParallelMode.TENSOR)
             set_tensor_parallel_attribute_by_partition(self.weight, num_partition)

+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            if self.has_weight:
+                weight = state_dict.pop(weight_key, None)
+                if weight is not None:
+                    local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={weight_key: -1, bias_key: 0},
+                                                           partition_states={weight_key: True, bias_key: False})
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict()
+        if self.has_weight:
+            local_state[weight_key] = self.weight
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={weight_key: -1, bias_key: 0},
+                                                        partition_states={weight_key: True, bias_key: False},
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
         if self.parallel_input:
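Editor's note: the dims and partition_states maps above spell out how Classifier1D is checkpointed: the weight is stored whole and split along its last dimension across the 1D group on load, while the bias is never partitioned and is replicated on every rank. A stand-alone sketch of that round trip in plain PyTorch (the real helpers live in colossalai.utils.checkpointing; this only illustrates the layout):

import torch

world_size = 4
weight = torch.randn(10, 64)   # [num_classes, in_features], sharded on load
bias = torch.randn(10)         # replicated on every rank

# Loading: partition the weight along dim -1; every rank keeps the full bias.
shards = torch.chunk(weight, world_size, dim=-1)
per_rank = [{'weight': w, 'bias': bias} for w in shards]

# Saving: gather the weight shards back along the same dimension; the bias can
# be taken from any single rank because it was never split.
restored = torch.cat([s['weight'] for s in per_rank], dim=-1)
assert torch.equal(restored, weight)
assert torch.equal(per_rank[0]['bias'], bias)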
@@ -241,6 +328,55 @@ class VocabParallelClassifier1D(ParallelLayer):
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, num_partition)

+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            if self.has_weight:
+                weight = state_dict.pop(weight_key, None)
+                if weight is not None:
+                    local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={weight_key: 0, bias_key: 0},
+                                                           partition_states={weight_key: True, bias_key: True})
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict()
+        if self.has_weight:
+            local_state[weight_key] = self.weight
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={weight_key: 0, bias_key: 0},
+                                                        partition_states={weight_key: True, bias_key: True},
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         assert input_.shape[-1] == self.weight.shape[-1], \
             'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
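Editor's note: VocabParallelClassifier1D uses the opposite convention: both weight and bias are partitioned along dim 0, so each rank owns the output rows for its slice of the vocabulary. A quick shape check of what that mapping implies (illustrative values, not library code):

import torch

world_size, num_classes, hidden = 4, 32, 16
weight = torch.randn(num_classes, hidden)
bias = torch.randn(num_classes)

# Both tensors are split along dim 0, one contiguous block of classes per rank.
weight_shards = torch.chunk(weight, world_size, dim=0)
bias_shards = torch.chunk(bias, world_size, dim=0)

for w, b in zip(weight_shards, bias_shards):
    assert w.shape == (num_classes // world_size, hidden)
    assert b.shape == (num_classes // world_size,)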
@@ -328,6 +464,52 @@ class Linear1D_Col(ParallelLayer):
         if self.bias is not None:
             set_tensor_parallel_attribute_by_partition(self.bias, num_partition)

+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={weight_key: 0, bias_key: 0},
+                                                           partition_states={weight_key: True, bias_key: True})
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={weight_key: 0, bias_key: 0},
+                                                        partition_states={weight_key: True, bias_key: True},
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
         assert input_.shape[-1] == self.weight.shape[-1], \
             'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
@@ -420,6 +602,52 @@ class Linear1D_Row(ParallelLayer):
             num_partition = gpc.get_world_size(ParallelMode.TENSOR)
             set_tensor_parallel_attribute_by_partition(self.weight, num_partition)

+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={weight_key: -1, bias_key: 0},
+                                                           partition_states={weight_key: True, bias_key: False})
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={weight_key: -1, bias_key: 0},
+                                                        partition_states={weight_key: True, bias_key: False},
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
         if self.parallel_input:
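Editor's note: together, the Linear1D_Col and Linear1D_Row hunks encode the usual column/row-parallel split: Linear1D_Col shards weight and bias along dim 0 (output features), while Linear1D_Row shards the weight along dim -1 (input features) and keeps the bias unpartitioned. A minimal sketch of the two layouts, assuming only what the dims maps state:

import torch

world_size, in_features, out_features = 2, 8, 6
weight = torch.randn(out_features, in_features)
bias = torch.randn(out_features)

# Column parallel: split outputs; each rank also owns a slice of the bias.
col_weights = torch.chunk(weight, world_size, dim=0)
col_biases = torch.chunk(bias, world_size, dim=0)

# Row parallel: split inputs; the bias stays whole on every rank.
row_weights = torch.chunk(weight, world_size, dim=-1)

assert col_weights[0].shape == (out_features // world_size, in_features)
assert col_biases[0].shape == (out_features // world_size,)
assert row_weights[0].shape == (out_features, in_features // world_size)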
@@ -514,6 +742,31 @@ class Embedding1D(ParallelLayer):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)

+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={weight_key: -1},
+                                                           partition_states={weight_key: True})
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        local_state = OrderedDict({weight_key: self.weight})
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={weight_key: -1},
+                                                        partition_states={weight_key: True},
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
@@ -594,10 +847,35 @@ class VocabParallelEmbedding1D(torch.nn.Module):
     def _fill_padding_idx_with_zero(self) -> None:
         if self.padding_idx is not None and \
                 self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+
+        local_state = partition_tensor_parallel_state_dict(local_state,
+                                                           ParallelMode.PARALLEL_1D,
+                                                           dims={weight_key: 0},
+                                                           partition_states={weight_key: True})
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        local_state = OrderedDict({weight_key: self.weight})
+        local_state = gather_tensor_parallel_state_dict(local_state,
+                                                        ParallelMode.PARALLEL_1D,
+                                                        dims={weight_key: 0},
+                                                        partition_states={weight_key: True},
+                                                        keep_vars=keep_vars)
+        destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.
         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
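Editor's note: the context lines above show why VocabParallelEmbedding1D shards its weight along dim 0: each rank owns a contiguous [vocab_start_index, vocab_end_index) slice of the vocabulary, zeroes padding_idx only when it falls inside that slice, and masks out-of-range token ids in forward. A single-process sketch of that range bookkeeping (the values are made up; in the full layer the masked rows would typically be zeroed before a cross-rank reduction):

import torch

vocab_size, world_size, rank = 16, 4, 1
vocab_per_rank = vocab_size // world_size
vocab_start_index = rank * vocab_per_rank
vocab_end_index = vocab_start_index + vocab_per_rank   # rank 1 owns ids 4..7

input_ = torch.tensor([0, 3, 4, 7, 9, 15])
# Token ids outside this rank's slice are flagged ...
input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)
# ... and remapped into the local table; masked positions fall back to row 0.
masked_input = input_ - vocab_start_index
masked_input[input_mask] = 0
print(input_mask)     # tensor([ True,  True, False, False,  True,  True])
print(masked_input)   # tensor([0, 0, 0, 3, 0, 0])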
@@ -637,3 +915,66 @@ class Dropout1D(ParallelLayer):
         else:
             output = F.dropout(input_, self.p, self.training, self.inplace)
         return output
+
+
+@LAYERS.register_module
+class PatchEmbedding1D(ColossalaiModule):
+    """
+    2D Image to Patch Embedding
+
+    :param img_size: image size
+    :type img_size: int
+    :param patch_size: patch size
+    :type patch_size: int
+    :param in_chans: number of channels of input image
+    :type in_chans: int
+    :param embed_size: size of embedding
+    :type embed_size: int
+    :param dtype: The dtype of parameters, defaults to None
+    :type dtype: torch.dtype, optional
+    :param flatten: whether to flatten output tensor, defaults to True
+    :type flatten: bool, optional
+    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
+    :type weight_initializer: typing.Callable, optional
+    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
+    :type bias_initializer: typing.Callable, optional
+    :param position_embed_initializer: The initializer of position embedding, defaults to zero
+    :type position_embed_initializer: typing.Callable, optional
+    """
+
+    def __init__(self,
+                 img_size: int,
+                 patch_size: int,
+                 in_chans: int,
+                 embed_size: int,
+                 dtype: torch.dtype = None,
+                 flatten: bool = True,
+                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+                 position_embed_initializer: Callable = init.zeros_()):
+        embed = VanillaPatchEmbedding(img_size,
+                                      patch_size,
+                                      in_chans,
+                                      embed_size,
+                                      dtype=dtype,
+                                      flatten=flatten,
+                                      weight_initializer=weight_initializer,
+                                      bias_initializer=bias_initializer,
+                                      position_embed_initializer=position_embed_initializer)
+        super().__init__(embed)
+
+    def _load_from_state_dict(self, state_dict, prefix, *args):
+        local_state = OrderedDict()
+        param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed']
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            for key in param_keys:
+                param = state_dict.pop(key, None)
+                if param is not None:
+                    local_state[key] = param
+
+        local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
+        super()._load_from_state_dict(local_state, prefix, *args)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            super()._save_to_state_dict(destination, prefix, keep_vars)
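Editor's note: with these hooks in place, checkpointing stays on the ordinary PyTorch state-dict path: _save_to_state_dict gathers shards (or keeps a single copy for replicated layers such as LayerNorm1D and PatchEmbedding1D) so the file holds full tensors, and _load_from_state_dict re-partitions or broadcasts them on the way back in. A hedged usage sketch with a stand-in model (the commit itself does not prescribe this workflow):

import torch
import torch.nn as nn

# Stand-in model: in practice this would be built from the 1D layers in this
# file (Linear1D, LayerNorm1D, Embedding1D, ...), whose overridden hooks make
# the two calls below work without any extra user code.
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

torch.save(model.state_dict(), 'checkpoint.pt')      # with 1D layers, the file still holds full tensors
model.load_state_dict(torch.load('checkpoint.pt'))   # with 1D layers, each rank reloads only its shard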