OpenDAS / ColossalAI · Commits

Unverified commit 7636d518, authored Apr 01, 2022 by アマデウス, committed by GitHub on Apr 01, 2022

[model checkpoint] updated saving/loading for 2d layers (#595)

Parent: cd13b638

Showing 1 changed file with 575 additions and 8 deletions:

colossalai/nn/layer/parallel_2d/layers.py (+575, -8) · view file @ 7636d518
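This commit adds `_save_to_state_dict` / `_load_from_state_dict` overrides to each 2D tensor-parallel layer in the file. The pattern repeats throughout: on save, each rank's shard is gathered first across the column group (`PARALLEL_2D_COL`), then across the row group (`PARALLEL_2D_ROW`), so only the rank with local rank 0 in `ParallelMode.TENSOR` ends up holding and writing the full tensors; on load, that rank pops the full tensors from the incoming state dict and the partition helpers scatter shards back out in the reverse order. As a rough single-process mental model (an illustration only, not ColossalAI's implementation, which communicates over process groups), the helpers behave like a chunk/concatenate along the dimension given in `dims`, skipped for keys whose `partition_states` entry is False:

    import torch

    def partition_sketch(state, rank, world_size, dims, partition_states):
        """Keep only this rank's shard of each partitioned tensor."""
        out = {}
        for key, tensor in state.items():
            if partition_states.get(key, False):
                out[key] = torch.chunk(tensor, world_size, dim=dims[key])[rank].clone()
            else:
                out[key] = tensor  # replicated state passes through whole
        return out

    def gather_sketch(shards, dims, partition_states):
        """Reassemble full tensors from every rank's shard (inverse of partition)."""
        out = {}
        for key in shards[0]:
            if partition_states.get(key, False):
                out[key] = torch.cat([s[key] for s in shards], dim=dims[key])
            else:
                out[key] = shards[0][key]
        return out

    # round trip: partition a 4x4 weight across 2 "ranks", then gather it back
    full = {'weight': torch.arange(16.).reshape(4, 4)}
    shards = [partition_sketch(full, r, 2, {'weight': 0}, {'weight': True}) for r in range(2)]
    restored = gather_sketch(shards, {'weight': 0}, {'weight': True})
    assert torch.equal(restored['weight'], full['weight'])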
import math
from collections import OrderedDict
from typing import Callable

import torch
...
@@ -10,13 +11,15 @@ from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import init as init
from colossalai.registry import LAYERS
+from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict
from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch.nn import Parameter

from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import *
+from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d,
+                         reduce_scatter_tensor_2d, split_tensor_2d)
from ._utils import assert_summa_initialization, get_summa_dim_from_env
...
@@ -39,6 +42,7 @@ class Linear2D(ParallelLayer):
    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
...
@@ -90,6 +94,91 @@ class Linear2D(ParallelLayer):
        if self.bias is not None:
            bias_initializer(self.bias, fan_in=fan_in)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight.transpose(0, 1)
            # bias
            if self.bias is not None:
                bias = state_dict.pop(bias_key, None)
                if bias is not None:
                    local_state[bias_key] = bias

        # partition in row groups
        if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                ParallelMode.PARALLEL_2D_ROW,
                dims={weight_key: -1, bias_key: 0},
                partition_states={weight_key: True, bias_key: True},
            )
        # partition in column groups
        local_state = partition_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_2D_COL,
            dims={weight_key: 0, bias_key: 0},
            partition_states={weight_key: True, bias_key: True},
        )

        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        local_state = OrderedDict({weight_key: self.weight})
        if self.bias is not None:
            local_state[bias_key] = self.bias

        # gather in column groups
        local_state = gather_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_2D_COL,
            dims={weight_key: 0, bias_key: 0},
            partition_states={weight_key: True, bias_key: True},
            keep_vars=keep_vars,
        )
        # gather in row groups
        if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                ParallelMode.PARALLEL_2D_ROW,
                dims={weight_key: -1, bias_key: 0},
                partition_states={weight_key: True, bias_key: True},
                keep_vars=keep_vars,
            )
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            local_state[weight_key] = local_state[weight_key].transpose(0, 1)
            destination.update(local_state)

    def forward(self, x: Tensor) -> Tensor:
        # input: [m/q, n/q, k/q]
        # output: [m/q, n/q, h/q]
...
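Note the `transpose(0, 1)` on both paths above: `Linear2D` keeps its local weight shard in (in_features, out_features) layout, while PyTorch checkpoints conventionally store `weight` as (out_features, in_features), so the tensor is flipped on the way in and flipped back on the way out. A hedged usage sketch of the resulting round trip follows; it assumes `colossalai.launch` has already initialised `gpc`, the variable `model` (a module containing 2D layers) is hypothetical, and the exact handling of non-zero ranks during load may differ from ColossalAI's own checkpoint utilities:

    import torch

    # Every rank must participate in the collective gathers inside
    # _save_to_state_dict, but only TENSOR local rank 0 receives the
    # full, un-partitioned tensors and should write the file.
    state = model.state_dict()
    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        torch.save(state, 'checkpoint.pt')

    # On load, rank 0 supplies the full state dict; the overrides above
    # re-partition it across the row and column process groups.
    full_state = {}
    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        full_state = torch.load('checkpoint.pt', map_location='cpu')
    model.load_state_dict(full_state, strict=False)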
@@ -129,6 +218,7 @@ class LayerNorm2D(ParallelLayer):
        eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
    """

    def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
        super().__init__()
...
@@ -148,14 +238,95 @@ class LayerNorm2D(ParallelLayer):
        # create parameters
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}

-        self.gamma = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
-        self.beta = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
+        self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))

        self._set_tensor_parallel_attributes()

    def _set_tensor_parallel_attributes(self):
-        set_tensor_parallel_attribute_by_partition(self.gamma, self.summa_dim**2)
-        set_tensor_parallel_attribute_by_partition(self.beta, self.summa_dim**2)
+        set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
+        set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight
            # bias
            bias = state_dict.pop(bias_key, None)
            if bias is not None:
                local_state[bias_key] = bias

        # partition in row groups
        if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                ParallelMode.PARALLEL_2D_ROW,
                dims={weight_key: 0, bias_key: 0},
                partition_states={weight_key: True, bias_key: True},
            )
        # partition in column groups
        local_state = partition_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_2D_COL,
            dims={weight_key: 0, bias_key: 0},
            partition_states={weight_key: True, bias_key: True},
        )

        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})

        # gather in column groups
        local_state = gather_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_2D_COL,
            dims={weight_key: 0, bias_key: 0},
            partition_states={weight_key: True, bias_key: True},
            keep_vars=keep_vars,
        )
        # gather in row groups
        if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                ParallelMode.PARALLEL_2D_ROW,
                dims={weight_key: 0, bias_key: 0},
                partition_states={weight_key: True, bias_key: True},
                keep_vars=keep_vars,
            )
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            destination.update(local_state)

    def forward(self, x: Tensor) -> Tensor:
        with torch.no_grad():
...
@@ -174,10 +345,10 @@ class LayerNorm2D(ParallelLayer):
        output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
                              ParallelMode.PARALLEL_2D_COL)
-        bias = add_bias_2d(None, self.beta, self.partitioned_partition, self.row_rank, self.col_rank,
+        bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
                           ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
                           self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
                           self.tensor_parallel_size)
-        scale = add_bias_2d(None, self.gamma, self.partitioned_partition, self.row_rank, self.col_rank,
+        scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank,
                            ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
                            self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
                            self.tensor_parallel_size)
        output = torch.addcmul(bias, scale, output)
...
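Besides adding the checkpoint hooks, the `LayerNorm2D` hunks above rename `gamma`/`beta` to `weight`/`bias`, so a gathered checkpoint uses the same parameter keys as `torch.nn.LayerNorm`. The forward then applies the affine transform as `torch.addcmul(bias, scale, output)`, i.e. `bias + scale * x_hat`. A single-process sanity check of that identity against `torch.nn.functional.layer_norm` (the shapes here are arbitrary, chosen only for the check):

    import torch

    x = torch.randn(2, 8)
    weight, bias = torch.randn(8), torch.randn(8)

    # normalise, then apply out = bias + weight * x_hat via addcmul
    x_hat = (x - x.mean(-1, keepdim=True)) / torch.sqrt(x.var(-1, unbiased=False, keepdim=True) + 1e-5)
    out = torch.addcmul(bias, weight, x_hat)

    expected = torch.nn.functional.layer_norm(x, (8,), weight, bias, eps=1e-5)
    assert torch.allclose(out, expected, atol=1e-5)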
@@ -205,6 +376,7 @@ class PatchEmbedding2D(ParallelLayer):
    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(self,
                 img_size: int,
                 patch_size: int,
...
@@ -260,6 +432,120 @@ class PatchEmbedding2D(ParallelLayer):
            bias_initializer(self.bias, fan_in=fan_in)
            position_embed_initializer(self.pos_embed)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        cls_token_key = prefix + 'cls_token'
        pos_embed_key = prefix + 'pos_embed'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight
            # bias
            bias = state_dict.pop(bias_key, None)
            if bias is not None:
                local_state[bias_key] = bias
            # cls token
            cls_token = state_dict.pop(cls_token_key, None)
            if cls_token is not None:
                local_state[cls_token_key] = cls_token
            # pos embed
            pos_embed = state_dict.pop(pos_embed_key, None)
            if pos_embed is not None:
                local_state[pos_embed_key] = pos_embed

        # partition in row groups
        if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                ParallelMode.PARALLEL_2D_ROW,
                dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
                partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
            )
        # partition in column groups
        local_state = partition_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_2D_COL,
            dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
            partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
        )

        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        cls_token_key = prefix + 'cls_token'
        pos_embed_key = prefix + 'pos_embed'
        local_state = OrderedDict({
            weight_key: self.weight,
            bias_key: self.bias,
            cls_token_key: self.cls_token,
            pos_embed_key: self.pos_embed
        })

        # gather in column groups
        local_state = gather_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_2D_COL,
            dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
            partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
            keep_vars=keep_vars,
        )
        # gather in row groups
        if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                ParallelMode.PARALLEL_2D_ROW,
                dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
                partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
                keep_vars=keep_vars,
            )
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        input_ = split_tensor_2d(input_)
...
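For `PatchEmbedding2D`, the `dims` maps above split the projection `weight` and `bias` along dim 0, but `cls_token` and `pos_embed` along dim -1: the class token and position embeddings are only ever sharded along their embedding dimension. A small sketch of the resulting per-rank shard shapes on a q-by-q grid (the ViT-style shapes below are assumptions for illustration, not taken from this diff):

    import torch

    q, embed_dim, num_patches = 2, 768, 196
    weight = torch.empty(embed_dim, 3, 16, 16)             # conv-style patch projection
    pos_embed = torch.empty(1, num_patches + 1, embed_dim)

    # each tensor is chunked once per group: row group, then column group
    w_shard = torch.chunk(weight, q, dim=0)[0]
    w_shard = torch.chunk(w_shard, q, dim=0)[0]
    p_shard = torch.chunk(pos_embed, q, dim=-1)[0]
    p_shard = torch.chunk(p_shard, q, dim=-1)[0]

    print(w_shard.shape)   # torch.Size([192, 3, 16, 16])
    print(p_shard.shape)   # torch.Size([1, 197, 192])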
@@ -313,6 +599,7 @@ class Embedding2D(ParallelLayer):
    More details about initializer please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
...
@@ -353,6 +640,57 @@ class Embedding2D(ParallelLayer):
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight

        # partition in row groups
        if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                ParallelMode.PARALLEL_2D_ROW,
                dims={weight_key: -1},
                partition_states={weight_key: True},
            )
        # partition in column groups
        local_state = partition_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_2D_COL,
            dims={weight_key: -1},
            partition_states={weight_key: True},
        )

        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        local_state = OrderedDict({weight_key: self.weight})

        # gather in column groups
        local_state = gather_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_2D_COL,
            dims={weight_key: -1},
            partition_states={weight_key: True},
            keep_vars=keep_vars,
        )
        # gather in row groups
        if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                ParallelMode.PARALLEL_2D_ROW,
                dims={weight_key: -1},
                partition_states={weight_key: True},
                keep_vars=keep_vars,
            )
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        input_ = split_tensor_2d(input_)
...
@@ -392,6 +730,7 @@ class VocabParallelEmbedding2D(torch.nn.Module):
    More details about initializer please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
...
@@ -439,6 +778,57 @@ class VocabParallelEmbedding2D(torch.nn.Module):
            with torch.no_grad():
                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight

        # partition in row groups
        if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                ParallelMode.PARALLEL_2D_ROW,
                dims={weight_key: -1},
                partition_states={weight_key: True},
            )
        # partition in column groups
        local_state = partition_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_2D_COL,
            dims={weight_key: 0},
            partition_states={weight_key: True},
        )

        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        local_state = OrderedDict({weight_key: self.weight})

        # gather in column groups
        local_state = gather_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_2D_COL,
            dims={weight_key: 0},
            partition_states={weight_key: True},
            keep_vars=keep_vars,
        )
        # gather in row groups
        if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                ParallelMode.PARALLEL_2D_ROW,
                dims={weight_key: -1},
                partition_states={weight_key: True},
                keep_vars=keep_vars,
            )
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
        masked_input = input_.clone() - self.vocab_start_index
...
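In the vocab-parallel embedding, each rank owns the contiguous vocabulary slice `[vocab_start_index, vocab_end_index)`, so `forward` first flags token ids outside the local slice and shifts the rest into local coordinates. A toy single-process illustration of that masking follows; the zero-fill of masked rows and the cross-rank reduction after the lookup are assumptions based on the usual vocab-parallel pattern, since those lines are elided in the diff above:

    import torch

    input_ = torch.tensor([1, 5, 3, 7])
    vocab_start_index, vocab_end_index = 4, 8   # this "rank" owns ids 4..7

    input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)
    masked_input = input_.clone() - vocab_start_index
    masked_input[input_mask] = 0                # clamp to any valid local id

    weight = torch.randn(4, 16)                 # local shard of the embedding table
    output = weight[masked_input]
    output[input_mask, :] = 0.0                 # out-of-range rows contribute nothing
    # an all-reduce over the vocab-parallel group would complete the lookup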
@@ -470,6 +860,7 @@ class Classifier2D(ParallelLayer):
    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(self,
                 in_features: int,
                 num_classes: int,
...
@@ -522,6 +913,93 @@ class Classifier2D(ParallelLayer):
            broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL)
            broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            if self.has_weight:
                weight = state_dict.pop(weight_key, None)
                if weight is not None:
                    local_state[weight_key] = weight
            # bias
            if self.bias is not None:
                bias = state_dict.pop(bias_key, None)
                if bias is not None:
                    local_state[bias_key] = bias

        # partition in row groups
        if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                ParallelMode.PARALLEL_2D_ROW,
                dims={weight_key: -1, bias_key: 0},
                partition_states={weight_key: True, bias_key: False},
            )
        # partition in column groups
        local_state = partition_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_2D_COL,
            dims={weight_key: -1, bias_key: 0},
            partition_states={weight_key: True, bias_key: False},
        )

        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        local_state = OrderedDict()
        if self.has_weight:
            local_state[weight_key] = self.weight
        if self.bias is not None:
            local_state[bias_key] = self.bias

        # gather in column groups
        local_state = gather_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_2D_COL,
            dims={weight_key: -1, bias_key: 0},
            partition_states={weight_key: True, bias_key: False},
            keep_vars=keep_vars,
        )
        # gather in row groups
        if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                ParallelMode.PARALLEL_2D_ROW,
                dims={weight_key: -1, bias_key: 0},
                partition_states={weight_key: True, bias_key: False},
                keep_vars=keep_vars,
            )
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        out_shape = input_.shape[:-1] + (self.num_classes, )
...
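Unlike the other layers in this file, `Classifier2D` passes `partition_states={bias_key: False}`: its bias (one scalar per class) is replicated across the grid and kept in sync by the `broadcast` calls above, so the partition/gather helpers must pass it through whole instead of chunking it; only the weight is sharded, along dim -1 (the input-feature dimension). Compare `VocabParallelClassifier2D` below, which sets both entries to True because it splits the classes themselves across ranks. A self-contained sketch of the distinction, mirroring the chunk-along-`dims` model from the first example:

    import torch

    state = {'weight': torch.randn(10, 64), 'bias': torch.randn(10)}
    dims = {'weight': -1, 'bias': 0}
    partition_states = {'weight': True, 'bias': False}   # bias stays replicated

    shard = {k: torch.chunk(v, 2, dim=dims[k])[0] if partition_states[k] else v
             for k, v in state.items()}

    print(shard['weight'].shape)   # torch.Size([10, 32]) -- sharded
    print(shard['bias'].shape)     # torch.Size([10])     -- left whole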
@@ -548,6 +1026,7 @@ class VocabParallelClassifier2D(ParallelLayer):
    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(self,
                 in_features: int,
                 num_classes: int,
...
@@ -605,6 +1084,94 @@ class VocabParallelClassifier2D(ParallelLayer):
        if self.bias is not None:
            bias_initializer(self.bias, fan_in=fan_in)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            if self.has_weight:
                weight = state_dict.pop(weight_key, None)
                if weight is not None:
                    local_state[weight_key] = weight
            # bias
            if self.bias is not None:
                bias = state_dict.pop(bias_key, None)
                if bias is not None:
                    local_state[bias_key] = bias

        # partition in row groups
        if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                ParallelMode.PARALLEL_2D_ROW,
                dims={weight_key: -1, bias_key: 0},
                partition_states={weight_key: True, bias_key: True},
            )
        # partition in column groups
        local_state = partition_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_2D_COL,
            dims={weight_key: 0, bias_key: 0},
            partition_states={weight_key: True, bias_key: True},
        )

        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        local_state = OrderedDict()
        if self.has_weight:
            local_state[weight_key] = self.weight
        if self.bias is not None:
            local_state[bias_key] = self.bias

        # gather in column groups
        local_state = gather_tensor_parallel_state_dict(
            local_state,
            ParallelMode.PARALLEL_2D_COL,
            dims={weight_key: 0, bias_key: 0},
            partition_states={weight_key: True, bias_key: True},
            keep_vars=keep_vars,
        )
        # gather in row groups
        if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                ParallelMode.PARALLEL_2D_ROW,
                dims={weight_key: -1, bias_key: 0},
                partition_states={weight_key: True, bias_key: True},
                keep_vars=keep_vars,
            )
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            local_state[weight_key] = local_state[weight_key].transpose(0, 1)
            destination.update(local_state)

    def forward(self, x: Tensor) -> Tensor:
        # input: [m/q, n/q, k/q]
        # output: [m/q, n/q, h/q]
...