Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
93089ed7
Unverified
Commit
93089ed7
authored
Apr 01, 2022
by
アマデウス
Committed by
GitHub
Apr 01, 2022
Browse files
[model checkpoint] updated saving/loading for 2.5d layers (#596)
parent
6302069c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
537 additions
and
7 deletions
+537
-7
colossalai/nn/layer/parallel_2p5d/layers.py
colossalai/nn/layer/parallel_2p5d/layers.py
+537
-7
No files found.
colossalai/nn/layer/parallel_2p5d/layers.py
View file @
93089ed7
import
math
from
collections
import
OrderedDict
from
typing
import
Callable
import
torch
...
...
@@ -10,13 +11,15 @@ from colossalai.core import global_context as gpc
from
colossalai.global_variables
import
tensor_parallel_env
as
env
from
colossalai.nn
import
init
as
init
from
colossalai.registry
import
LAYERS
from
colossalai.utils.checkpointing
import
(
broadcast_state_dict
,
gather_tensor_parallel_state_dict
,
partition_tensor_parallel_state_dict
)
from
colossalai.utils.cuda
import
get_current_device
from
torch
import
Tensor
from
torch.nn
import
Parameter
from
..base_layer
import
ParallelLayer
from
..utils
import
divide
,
set_tensor_parallel_attribute_by_partition
,
to_2tuple
from
._operation
import
(
add_bias_2p5d
,
Matmul_AB_2p5D
,
Matmul_ABT_2p5D
,
all_gather_tensor_2p5d
,
classifier_2p5d
,
from
._operation
import
(
Matmul_AB_2p5D
,
Matmul_ABT_2p5D
,
add_bias_2p5d
,
all_gather_tensor_2p5d
,
classifier_2p5d
,
layernorm_2p5d
,
reduce_scatter_tensor_2p5d
,
split_tensor_2p5d
)
from
._utils
import
assert_tesseract_initialization
,
get_tesseract_dim_dep_from_env
...
...
@@ -40,6 +43,7 @@ class Linear2p5D(ParallelLayer):
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def
__init__
(
self
,
in_features
:
int
,
out_features
:
int
,
...
...
@@ -92,6 +96,96 @@ class Linear2p5D(ParallelLayer):
if
self
.
bias
is
not
None
:
bias_initializer
(
self
.
bias
,
fan_in
=
fan_in
)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Distribute ``weight``/``bias`` from a full ``state_dict`` to this rank.

    Steps, in order:

    1. The rank whose local rank in ``ParallelMode.TENSOR`` is 0 pops
       ``prefix + 'weight'`` and (only if this layer has a bias)
       ``prefix + 'bias'`` from ``state_dict``. The weight is transposed
       (dims 0 and 1 swapped); ``_save_to_state_dict`` applies the same
       transpose before exporting.
    2. Ranks with column-rank 0 and row-rank 0 call ``broadcast_state_dict``
       over ``PARALLEL_2P5D_DEP``.
    3. Ranks with row-rank 0 partition over ``PARALLEL_2P5D_COL``.
    4. Every rank partitions over ``PARALLEL_2P5D_ROW``.
    5. The resulting local dict is handed to the parent implementation.

    NOTE(review): block nesting was reconstructed from a flattened diff view;
    confirm against the upstream file.

    Args:
        state_dict: mapping that may contain the unpartitioned parameters;
            consumed entries are removed from it.
        prefix: key prefix for this module's parameters.
        *args, **kwargs: forwarded unchanged to the parent implementation.
    """
    local_state = OrderedDict()
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        # weight
        weight = state_dict.pop(weight_key, None)
        if weight is not None:
            local_state[weight_key] = weight.transpose(0, 1)
        # bias
        if self.bias is not None:
            bias = state_dict.pop(bias_key, None)
            if bias is not None:
                local_state[bias_key] = bias

    # broadcast in dep groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0 and \
            gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0:
        # Fix: the broadcast result used to be discarded, so receiving ranks
        # kept their empty ``local_state``. Presumably the helper returns the
        # broadcast dict (verify against colossalai.utils.checkpointing); the
        # None check keeps the old behaviour if it works in place instead.
        received = broadcast_state_dict(local_state, ParallelMode.PARALLEL_2P5D_DEP)
        if received is not None:
            local_state = received
    # partition in column groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0:
        local_state = partition_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_2P5D_COL,
            dims={
                weight_key: 0,
                bias_key: 0
            },
            partition_states={
                weight_key: True,
                bias_key: False
            },
        )
    # partition in row groups
    local_state = partition_tensor_parallel_state_dict(
        local_state,
        ParallelMode.PARALLEL_2P5D_ROW,
        dims={
            weight_key: -1,
            bias_key: 0
        },
        partition_states={
            weight_key: True,
            bias_key: True
        },
    )

    super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Gather the partitioned ``weight``/``bias`` and export them.

    Ranks whose local rank in ``PARALLEL_2P5D_DEP`` is not 0 do nothing.
    The others gather over ``PARALLEL_2P5D_ROW``; ranks with row-rank 0 then
    gather over ``PARALLEL_2P5D_COL``. Only the rank whose local rank in
    ``ParallelMode.TENSOR`` is 0 transposes the weight (dims 0 and 1) and
    writes the result into ``destination``.

    NOTE(review): block nesting was reconstructed from a flattened diff view;
    confirm against the upstream file.

    Args:
        destination: mapping updated in place with the gathered entries.
        prefix: key prefix for this module's parameters.
        keep_vars: forwarded to ``gather_tensor_parallel_state_dict``.
    """
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) != 0:
        return

    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'

    gathered = OrderedDict({weight_key: self.weight})
    if self.bias is not None:
        gathered[bias_key] = self.bias

    # gather in row groups
    row_dims = {weight_key: -1, bias_key: 0}
    row_flags = {weight_key: True, bias_key: True}
    gathered = gather_tensor_parallel_state_dict(
        gathered,
        ParallelMode.PARALLEL_2P5D_ROW,
        dims=row_dims,
        partition_states=row_flags,
        keep_vars=keep_vars,
    )

    # gather in column groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0:
        col_dims = {weight_key: 0, bias_key: 0}
        col_flags = {weight_key: True, bias_key: False}
        gathered = gather_tensor_parallel_state_dict(
            gathered,
            ParallelMode.PARALLEL_2P5D_COL,
            dims=col_dims,
            partition_states=col_flags,
            keep_vars=keep_vars,
        )

    if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
        return
    # Undo the transpose applied when the weight was loaded.
    gathered[weight_key] = gathered[weight_key].transpose(0, 1)
    destination.update(gathered)
def
forward
(
self
,
x
:
Tensor
)
->
Tensor
:
# input: [m/dq, n/q, k/q]
# output: [m/dq, n/q, h/q]
...
...
@@ -143,6 +237,7 @@ class LayerNorm2p5D(ParallelLayer):
eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
def
__init__
(
self
,
normalized_shape
:
int
,
eps
:
float
=
1e-05
,
dtype
=
None
):
super
().
__init__
()
...
...
@@ -163,14 +258,95 @@ class LayerNorm2p5D(ParallelLayer):
# create parameters
factory_kwargs
=
{
'device'
:
get_current_device
(),
'dtype'
:
dtype
}
self
.
gamma
=
Parameter
(
torch
.
ones
(
self
.
partitioned_partition
,
**
factory_kwargs
))
self
.
b
eta
=
Parameter
(
torch
.
zeros
(
self
.
partitioned_partition
,
**
factory_kwargs
))
self
.
weight
=
Parameter
(
torch
.
ones
(
self
.
partitioned_partition
,
**
factory_kwargs
))
self
.
b
ias
=
Parameter
(
torch
.
zeros
(
self
.
partitioned_partition
,
**
factory_kwargs
))
self
.
_set_tensor_parallel_attribute
()
def _set_tensor_parallel_attribute(self):
    """Tag the layer-norm parameters as tensor-parallel.

    Calls ``set_tensor_parallel_attribute_by_partition`` with
    ``self.tesseract_dim`` for ``self.weight`` and then ``self.bias``.
    """
    # ``gamma``/``beta`` were renamed to ``weight``/``bias``; the state-dict
    # methods and ``forward`` of this class use only the new names, so the
    # stale calls on the old names are dropped here.
    for param in (self.weight, self.bias):
        set_tensor_parallel_attribute_by_partition(param, self.tesseract_dim)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Distribute ``weight``/``bias`` from a full ``state_dict`` to this rank.

    The rank whose local rank in ``ParallelMode.TENSOR`` is 0 pops both
    entries from ``state_dict``. Ranks with column-rank 0 partition over
    ``PARALLEL_2P5D_ROW``; every rank then partitions over
    ``PARALLEL_2P5D_COL``. Both keys use dim 0 and are flagged ``True`` in
    ``partition_states`` for both calls. The result goes to the parent
    implementation.

    NOTE(review): block nesting was reconstructed from a flattened diff view;
    confirm against the upstream file.
    """
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    names = (weight_key, bias_key)

    shard = OrderedDict()
    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        # weight first, then bias; absent entries are simply skipped
        for name in names:
            tensor = state_dict.pop(name, None)
            if tensor is not None:
                shard[name] = tensor

    # partition in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        shard = partition_tensor_parallel_state_dict(
            shard,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims=dict.fromkeys(names, 0),
            partition_states=dict.fromkeys(names, True),
        )
    # partition in column groups
    shard = partition_tensor_parallel_state_dict(
        shard,
        ParallelMode.PARALLEL_2P5D_COL,
        dims=dict.fromkeys(names, 0),
        partition_states=dict.fromkeys(names, True),
    )

    super()._load_from_state_dict(shard, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Gather the partitioned ``weight``/``bias`` and export them.

    Every rank gathers over ``PARALLEL_2P5D_COL``; ranks with column-rank 0
    then gather over ``PARALLEL_2P5D_ROW``. Only the rank whose local rank in
    ``ParallelMode.TENSOR`` is 0 writes the result into ``destination``.

    NOTE(review): block nesting was reconstructed from a flattened diff view;
    confirm against the upstream file.
    """
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    names = (weight_key, bias_key)

    merged = OrderedDict()
    merged[weight_key] = self.weight
    merged[bias_key] = self.bias

    # gather in column groups
    merged = gather_tensor_parallel_state_dict(
        merged,
        ParallelMode.PARALLEL_2P5D_COL,
        dims=dict.fromkeys(names, 0),
        partition_states=dict.fromkeys(names, True),
        keep_vars=keep_vars,
    )
    # gather in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        merged = gather_tensor_parallel_state_dict(
            merged,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims=dict.fromkeys(names, 0),
            partition_states=dict.fromkeys(names, True),
            keep_vars=keep_vars,
        )

    if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
        return
    destination.update(merged)
def
forward
(
self
,
x
:
Tensor
)
->
Tensor
:
with
torch
.
no_grad
():
...
...
@@ -188,11 +364,11 @@ class LayerNorm2p5D(ParallelLayer):
Var_x
=
1.0
/
torch
.
sqrt
(
Var_x
+
self
.
variance_epsilon
)
output
=
layernorm_2p5d
(
x
,
E_x
,
Var_x
,
self
.
normalized_shape
,
ParallelMode
.
PARALLEL_2P5D_ROW
)
bias
=
add_bias_2p5d
(
None
,
self
.
b
eta
,
self
.
partitioned_partition
,
self
.
tesseract_dim
,
self
.
row_rank
,
bias
=
add_bias_2p5d
(
None
,
self
.
b
ias
,
self
.
partitioned_partition
,
self
.
tesseract_dim
,
self
.
row_rank
,
self
.
col_rank
,
self
.
dep_rank
,
ParallelMode
.
PARALLEL_2P5D_COL
,
True
,
self
.
data_parallel_rank
,
self
.
pipeline_parallel_rank
,
self
.
pipeline_parallel_size
,
self
.
tensor_parallel_size
)
scale
=
add_bias_2p5d
(
None
,
self
.
gamma
,
self
.
partitioned_partition
,
self
.
tesseract_dim
,
self
.
row_rank
,
scale
=
add_bias_2p5d
(
None
,
self
.
weight
,
self
.
partitioned_partition
,
self
.
tesseract_dim
,
self
.
row_rank
,
self
.
col_rank
,
self
.
dep_rank
,
ParallelMode
.
PARALLEL_2P5D_COL
,
True
,
self
.
data_parallel_rank
,
self
.
pipeline_parallel_rank
,
self
.
pipeline_parallel_size
,
self
.
tensor_parallel_size
)
...
...
@@ -221,6 +397,7 @@ class PatchEmbedding2p5D(ParallelLayer):
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def
__init__
(
self
,
img_size
:
int
,
patch_size
:
int
,
...
...
@@ -276,6 +453,120 @@ class PatchEmbedding2p5D(ParallelLayer):
bias_initializer
(
self
.
bias
,
fan_in
=
fan_in
)
position_embed_initializer
(
self
.
pos_embed
)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Distribute the patch-embedding parameters from a full ``state_dict``.

    Handles ``weight``, ``bias``, ``cls_token`` and ``pos_embed`` (each under
    ``prefix``). The rank whose local rank in ``ParallelMode.TENSOR`` is 0
    pops them from ``state_dict``. Ranks with column-rank 0 partition over
    ``PARALLEL_2P5D_ROW``; every rank then partitions over
    ``PARALLEL_2P5D_COL``. ``weight``/``bias`` use dim 0 and
    ``cls_token``/``pos_embed`` use dim -1 in both calls. The result goes to
    the parent implementation.

    NOTE(review): block nesting was reconstructed from a flattened diff view;
    confirm against the upstream file.
    """
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    cls_token_key = prefix + 'cls_token'
    pos_embed_key = prefix + 'pos_embed'
    names = (weight_key, bias_key, cls_token_key, pos_embed_key)

    def split_over(state, mode):
        # Same dims and flags for the row and the column pass; fresh dicts
        # are built on every call.
        return partition_tensor_parallel_state_dict(
            state,
            mode,
            dims={
                weight_key: 0,
                bias_key: 0,
                cls_token_key: -1,
                pos_embed_key: -1
            },
            partition_states=dict.fromkeys(names, True),
        )

    shard = OrderedDict()
    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        # weight, bias, cls token, pos embed: in that order
        for name in names:
            value = state_dict.pop(name, None)
            if value is not None:
                shard[name] = value

    # partition in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        shard = split_over(shard, ParallelMode.PARALLEL_2P5D_ROW)
    # partition in column groups
    shard = split_over(shard, ParallelMode.PARALLEL_2P5D_COL)

    super()._load_from_state_dict(shard, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Gather the partitioned patch-embedding parameters and export them.

    Covers ``weight``, ``bias``, ``cls_token`` and ``pos_embed``. Every rank
    gathers over ``PARALLEL_2P5D_COL``; ranks with column-rank 0 then gather
    over ``PARALLEL_2P5D_ROW``. Only the rank whose local rank in
    ``ParallelMode.TENSOR`` is 0 writes the result into ``destination``.

    NOTE(review): block nesting was reconstructed from a flattened diff view;
    confirm against the upstream file.
    """
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    cls_token_key = prefix + 'cls_token'
    pos_embed_key = prefix + 'pos_embed'
    names = (weight_key, bias_key, cls_token_key, pos_embed_key)

    def collect_over(state, mode):
        # Same dims and flags for the column and the row pass; fresh dicts
        # are built on every call.
        return gather_tensor_parallel_state_dict(
            state,
            mode,
            dims={
                weight_key: 0,
                bias_key: 0,
                cls_token_key: -1,
                pos_embed_key: -1
            },
            partition_states=dict.fromkeys(names, True),
            keep_vars=keep_vars,
        )

    merged = OrderedDict()
    merged[weight_key] = self.weight
    merged[bias_key] = self.bias
    merged[cls_token_key] = self.cls_token
    merged[pos_embed_key] = self.pos_embed

    # gather in column groups
    merged = collect_over(merged, ParallelMode.PARALLEL_2P5D_COL)
    # gather in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        merged = collect_over(merged, ParallelMode.PARALLEL_2P5D_ROW)

    if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
        return
    destination.update(merged)
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
input_
=
split_tensor_2p5d
(
input_
,
0
)
...
...
@@ -329,6 +620,7 @@ class Embedding2p5D(ParallelLayer):
More details about initializer please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
"""
def
__init__
(
self
,
num_embeddings
:
int
,
embedding_dim
:
int
,
...
...
@@ -369,6 +661,57 @@ class Embedding2p5D(ParallelLayer):
with
torch
.
no_grad
():
self
.
weight
[
self
.
padding_idx
].
fill_
(
0
)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Distribute the embedding ``weight`` from a full ``state_dict``.

    The rank whose local rank in ``ParallelMode.TENSOR`` is 0 pops
    ``prefix + 'weight'``. Ranks with column-rank 0 partition over
    ``PARALLEL_2P5D_ROW``; every rank then partitions over
    ``PARALLEL_2P5D_COL``. Both calls use dim -1. The result goes to the
    parent implementation.

    NOTE(review): block nesting was reconstructed from a flattened diff view;
    confirm against the upstream file.
    """
    weight_key = prefix + 'weight'
    shard = OrderedDict()

    on_tensor_root = gpc.get_local_rank(ParallelMode.TENSOR) == 0
    if on_tensor_root:
        full_weight = state_dict.pop(weight_key, None)
        if full_weight is not None:
            shard[weight_key] = full_weight

    # partition in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        shard = partition_tensor_parallel_state_dict(
            shard,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims={weight_key: -1},
            partition_states={weight_key: True},
        )
    # partition in column groups
    shard = partition_tensor_parallel_state_dict(
        shard,
        ParallelMode.PARALLEL_2P5D_COL,
        dims={weight_key: -1},
        partition_states={weight_key: True},
    )

    super()._load_from_state_dict(shard, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Gather the partitioned embedding ``weight`` and export it.

    Every rank gathers over ``PARALLEL_2P5D_COL``; ranks with column-rank 0
    then gather over ``PARALLEL_2P5D_ROW``. Both calls use dim -1. Only the
    rank whose local rank in ``ParallelMode.TENSOR`` is 0 writes the result
    into ``destination``.

    NOTE(review): block nesting was reconstructed from a flattened diff view;
    confirm against the upstream file.
    """
    weight_key = prefix + 'weight'
    merged = OrderedDict()
    merged[weight_key] = self.weight

    # gather in column groups
    merged = gather_tensor_parallel_state_dict(
        merged,
        ParallelMode.PARALLEL_2P5D_COL,
        dims={weight_key: -1},
        partition_states={weight_key: True},
        keep_vars=keep_vars,
    )
    # gather in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        merged = gather_tensor_parallel_state_dict(
            merged,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims={weight_key: -1},
            partition_states={weight_key: True},
            keep_vars=keep_vars,
        )

    if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
        return
    destination.update(merged)
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
input_
=
split_tensor_2p5d
(
input_
,
0
)
...
...
@@ -409,6 +752,7 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
More details about initializer please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def
__init__
(
self
,
num_embeddings
:
int
,
embedding_dim
:
int
,
...
...
@@ -456,6 +800,57 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
with
torch
.
no_grad
():
self
.
weight
[
self
.
padding_idx
-
self
.
vocab_start_index
].
fill_
(
0
)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Distribute the vocab-parallel embedding ``weight`` from ``state_dict``.

    The rank whose local rank in ``ParallelMode.TENSOR`` is 0 pops
    ``prefix + 'weight'``. Ranks with column-rank 0 partition over
    ``PARALLEL_2P5D_ROW`` using dim -1; every rank then partitions over
    ``PARALLEL_2P5D_COL`` using dim 0. The result goes to the parent
    implementation.

    NOTE(review): block nesting was reconstructed from a flattened diff view;
    confirm against the upstream file.
    """
    weight_key = prefix + 'weight'
    shard = OrderedDict()

    on_tensor_root = gpc.get_local_rank(ParallelMode.TENSOR) == 0
    if on_tensor_root:
        full_weight = state_dict.pop(weight_key, None)
        if full_weight is not None:
            shard[weight_key] = full_weight

    # partition in row groups (last dim)
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        shard = partition_tensor_parallel_state_dict(
            shard,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims={weight_key: -1},
            partition_states={weight_key: True},
        )
    # partition in column groups (first dim)
    shard = partition_tensor_parallel_state_dict(
        shard,
        ParallelMode.PARALLEL_2P5D_COL,
        dims={weight_key: 0},
        partition_states={weight_key: True},
    )

    super()._load_from_state_dict(shard, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Gather the vocab-parallel embedding ``weight`` and export it.

    Every rank gathers over ``PARALLEL_2P5D_COL`` using dim 0; ranks with
    column-rank 0 then gather over ``PARALLEL_2P5D_ROW`` using dim -1. Only
    the rank whose local rank in ``ParallelMode.TENSOR`` is 0 writes the
    result into ``destination``.

    NOTE(review): block nesting was reconstructed from a flattened diff view;
    confirm against the upstream file.
    """
    weight_key = prefix + 'weight'
    merged = OrderedDict()
    merged[weight_key] = self.weight

    # gather in column groups (first dim)
    merged = gather_tensor_parallel_state_dict(
        merged,
        ParallelMode.PARALLEL_2P5D_COL,
        dims={weight_key: 0},
        partition_states={weight_key: True},
        keep_vars=keep_vars,
    )
    # gather in row groups (last dim)
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        merged = gather_tensor_parallel_state_dict(
            merged,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims={weight_key: -1},
            partition_states={weight_key: True},
            keep_vars=keep_vars,
        )

    if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
        return
    destination.update(merged)
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
# Build the mask.
input_mask
=
(
input_
<
self
.
vocab_start_index
)
|
(
input_
>=
self
.
vocab_end_index
)
...
...
@@ -491,6 +886,7 @@ class Classifier2p5D(ParallelLayer):
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def
__init__
(
self
,
in_features
:
int
,
num_classes
:
int
,
...
...
@@ -544,6 +940,93 @@ class Classifier2p5D(ParallelLayer):
broadcast
(
self
.
bias
,
col_src_rank
,
ParallelMode
.
PARALLEL_2P5D_COL
)
broadcast
(
self
.
bias
,
row_src_rank
,
ParallelMode
.
PARALLEL_2P5D_ROW
)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Distribute the classifier ``weight``/``bias`` from a full ``state_dict``.

    On the rank whose local rank in ``ParallelMode.TENSOR`` is 0, the weight
    is popped only when ``self.has_weight`` is true and the bias only when
    ``self.bias`` is not ``None``. Ranks with column-rank 0 partition over
    ``PARALLEL_2P5D_ROW``; every rank then partitions over
    ``PARALLEL_2P5D_COL``. Both calls use dim -1 for the weight and dim 0 for
    the bias, and flag the weight ``True`` and the bias ``False`` in
    ``partition_states``. The result goes to the parent implementation.

    NOTE(review): block nesting was reconstructed from a flattened diff view;
    confirm against the upstream file.
    """
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    shard = OrderedDict()

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        # Collect the keys this layer owns, weight before bias.
        wanted = []
        if self.has_weight:
            wanted.append(weight_key)
        if self.bias is not None:
            wanted.append(bias_key)
        for name in wanted:
            value = state_dict.pop(name, None)
            if value is not None:
                shard[name] = value

    # partition in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        shard = partition_tensor_parallel_state_dict(
            shard,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims={weight_key: -1, bias_key: 0},
            partition_states={weight_key: True, bias_key: False},
        )
    # partition in column groups
    shard = partition_tensor_parallel_state_dict(
        shard,
        ParallelMode.PARALLEL_2P5D_COL,
        dims={weight_key: -1, bias_key: 0},
        partition_states={weight_key: True, bias_key: False},
    )

    super()._load_from_state_dict(shard, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Gather the classifier ``weight``/``bias`` and export them.

    The weight is included only when ``self.has_weight`` is true and the bias
    only when ``self.bias`` is not ``None``. Every rank gathers over
    ``PARALLEL_2P5D_COL``; ranks with column-rank 0 then gather over
    ``PARALLEL_2P5D_ROW``. Only the rank whose local rank in
    ``ParallelMode.TENSOR`` is 0 writes the result into ``destination``.

    NOTE(review): block nesting was reconstructed from a flattened diff view;
    confirm against the upstream file.
    """
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'

    merged = OrderedDict()
    if self.has_weight:
        merged[weight_key] = self.weight
    if self.bias is not None:
        merged[bias_key] = self.bias

    # gather in column groups
    merged = gather_tensor_parallel_state_dict(
        merged,
        ParallelMode.PARALLEL_2P5D_COL,
        dims={weight_key: -1, bias_key: 0},
        partition_states={weight_key: True, bias_key: False},
        keep_vars=keep_vars,
    )
    # gather in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        merged = gather_tensor_parallel_state_dict(
            merged,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims={weight_key: -1, bias_key: 0},
            partition_states={weight_key: True, bias_key: False},
            keep_vars=keep_vars,
        )

    if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
        return
    destination.update(merged)
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
out_shape
=
input_
.
shape
[:
-
1
]
+
(
self
.
num_classes
,
)
...
...
@@ -571,6 +1054,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def
__init__
(
self
,
in_features
:
int
,
num_classes
:
int
,
...
...
@@ -629,6 +1113,52 @@ class VocabParallelClassifier2p5D(ParallelLayer):
if
self
.
bias
is
not
None
:
bias_initializer
(
self
.
bias
,
fan_in
=
fan_in
)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Distribute the vocab-parallel classifier ``weight``/``bias``.

    On the rank whose local rank in ``ParallelMode.TENSOR`` is 0, the weight
    is popped only when ``self.has_weight`` is true and the bias only when
    ``self.bias`` is not ``None``. Ranks with column-rank 0 partition over
    ``PARALLEL_2P5D_ROW`` (weight dim -1, bias dim 0); every rank then
    partitions over ``PARALLEL_2P5D_COL`` (weight dim 0, bias dim 0). Both
    keys are flagged ``True`` in ``partition_states`` for both calls. The
    result goes to the parent implementation.

    NOTE(review): block nesting was reconstructed from a flattened diff view;
    confirm against the upstream file.
    """
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    shard = OrderedDict()

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        # Collect the keys this layer owns, weight before bias.
        wanted = []
        if self.has_weight:
            wanted.append(weight_key)
        if self.bias is not None:
            wanted.append(bias_key)
        for name in wanted:
            value = state_dict.pop(name, None)
            if value is not None:
                shard[name] = value

    # partition in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        shard = partition_tensor_parallel_state_dict(
            shard,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims={weight_key: -1, bias_key: 0},
            partition_states={weight_key: True, bias_key: True},
        )
    # partition in column groups
    shard = partition_tensor_parallel_state_dict(
        shard,
        ParallelMode.PARALLEL_2P5D_COL,
        dims={weight_key: 0, bias_key: 0},
        partition_states={weight_key: True, bias_key: True},
    )

    super()._load_from_state_dict(shard, prefix, *args, **kwargs)
def
forward
(
self
,
x
:
Tensor
)
->
Tensor
:
# input: [m/dq, n/q, k/q]
# output: [m/dq, n/q, h/q]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment