OpenDAS / ColossalAI · Commits

Commit 77ad24bf (Unverified)
[model checkpoint] updated saving/loading for 3d layers (#597)
Authored Apr 01, 2022 by アマデウス; committed by GitHub on Apr 01, 2022
Parent: 93089ed7

Showing 1 changed file with 565 additions and 3 deletions:
colossalai/nn/layer/parallel_3d/layers.py (+565, -3)
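All of the changes below implement checkpointing by overriding two torch.nn.Module hooks: _save_to_state_dict, which state_dict() calls to populate the checkpoint, and _load_from_state_dict, which load_state_dict() calls to consume it. A minimal sketch of that hook mechanism, using a toy module that, like Linear3D below, stores its weight transposed internally (TransposedLinear is illustrative only, not part of ColossalAI):

import torch

class TransposedLinear(torch.nn.Module):
    """Toy module that stores its weight transposed but checkpoints it untransposed."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # internal layout is (in, out), the reverse of torch.nn.Linear
        self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # export in the conventional (out, in) layout
        weight = self.weight if keep_vars else self.weight.detach()
        destination[prefix + 'weight'] = weight.transpose(0, 1)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # convert the conventional layout back to the internal one,
        # then defer to the default loader for the actual copy
        weight = state_dict.pop(prefix + 'weight', None)
        if weight is not None:
            state_dict[prefix + 'weight'] = weight.transpose(0, 1)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

layer = TransposedLinear(4, 8)
ckpt = layer.state_dict()        # ckpt['weight'].shape == (8, 4)
layer.load_state_dict(ckpt)      # round-trips back to the internal (4, 8)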
import math
from collections import OrderedDict
from typing import Callable

import torch

...

@@ -12,13 +13,15 @@ from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.nn import init as init
 from colossalai.nn.layer.base_layer import ParallelLayer
 from colossalai.registry import LAYERS
+from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
+                                            partition_tensor_parallel_state_dict)
 from colossalai.utils.cuda import get_current_device
 from torch import Tensor
 from torch.nn import Parameter
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import layernorm_3d, linear_3d, classifier_3d, split_tensor_3d
-from ._operation import all_gather_tensor_3d, reduce_scatter_tensor_3d, broadcast_weight_3d_from_diagonal
+from ._operation import (all_gather_tensor_3d, broadcast_weight_3d_from_diagonal, classifier_3d, layernorm_3d,
+                         linear_3d, reduce_scatter_tensor_3d, split_tensor_3d)
 from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
...

@@ -61,6 +64,67 @@ class LayerNorm3D(ParallelLayer):
        init.zeros_()(self.bias)
        init.ones_()(self.weight)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight
            # bias
            bias = state_dict.pop(bias_key, None)
            if bias is not None:
                local_state[bias_key] = bias

        # partition in output groups
        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
                gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                self.output_parallel_mode,
                dims={weight_key: 0, bias_key: 0},
                partition_states={weight_key: True, bias_key: True},
            )
        # broadcast in input groups
        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
        # broadcast in weight groups
        local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)

        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})

        # gather in output groups
        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
                gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                self.output_parallel_mode,
                dims={weight_key: 0, bias_key: 0},
                partition_states={weight_key: True, bias_key: True},
                keep_vars=keep_vars,
            )
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon,
                            self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode)
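The load path above runs in three stages: tensor-parallel rank 0 pops the full weight and bias from the incoming checkpoint, the origin rank of the input/weight groups shards them along dim 0 across the output group, and two broadcast_state_dict calls replicate the shards through the input and weight groups. A rough sketch of the sharding step for one tensor, assuming partition_tensor_parallel_state_dict() chunks each entry evenly along the requested dim (an assumption from the call site; the real helper works over process groups):

import torch

def partition_along_dim(tensor: torch.Tensor, dim: int, world_size: int, rank: int) -> torch.Tensor:
    # split into world_size equal chunks and keep only this rank's shard
    return torch.chunk(tensor, world_size, dim=dim)[rank].contiguous()

full_weight = torch.arange(12.)     # e.g. a LayerNorm weight of size 12
shard = partition_along_dim(full_weight, dim=0, world_size=4, rank=1)
print(shard)                        # tensor([3., 4., 5.])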
...

@@ -135,6 +199,122 @@ class Linear3D(ParallelLayer):
            broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
            broadcast(self.bias, output_src_rank, self.output_parallel_mode)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight.transpose(0, 1)
            # bias
            if self.bias is not None:
                bias = state_dict.pop(bias_key, None)
                if bias is not None:
                    local_state[bias_key] = bias

        # partition in output groups
        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
                gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                self.output_parallel_mode,
                dims={weight_key: 0, bias_key: 0},
                partition_states={weight_key: True, bias_key: False},
            )
        # partition in input groups
        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                self.input_parallel_mode,
                dims={weight_key: -1, bias_key: 0},
                partition_states={weight_key: True, bias_key: True},
            )
        # partition in weight groups
        local_state = partition_tensor_parallel_state_dict(
            local_state,
            self.weight_parallel_mode,
            dims={weight_key: -1, bias_key: 0},
            partition_states={weight_key: True, bias_key: False},
        )

        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        local_state = OrderedDict({weight_key: self.weight})
        if self.bias is not None:
            local_state[bias_key] = self.bias

        # gather in weight groups
        local_state = gather_tensor_parallel_state_dict(
            local_state,
            self.weight_parallel_mode,
            dims={weight_key: -1, bias_key: 0},
            partition_states={weight_key: True, bias_key: False},
            keep_vars=keep_vars,
        )
        # gather in input groups
        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                self.input_parallel_mode,
                dims={weight_key: -1, bias_key: 0},
                partition_states={weight_key: True, bias_key: True},
                keep_vars=keep_vars,
            )
        # gather in output groups
        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
                gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                self.output_parallel_mode,
                dims={weight_key: 0, bias_key: 0},
                partition_states={weight_key: True, bias_key: False},
                keep_vars=keep_vars,
            )
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            local_state[weight_key] = local_state[weight_key].transpose(0, 1)
            destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
                         self.output_parallel_mode)
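Linear3D keeps its weight in an (in_features, out_features) layout, the reverse of torch.nn.Linear, so _load_from_state_dict transposes the checkpoint weight into that layout and _save_to_state_dict transposes it back after the gathers. Each gather also mirrors the dims of the corresponding partition call, so saving is the exact inverse of loading. A shape-algebra sketch that models the partition/gather pair as torch.chunk and torch.cat (the real helpers use collectives over process groups):

import torch

full = torch.randn(8, 6)                 # hypothetical full (in, out) weight
shards = torch.chunk(full, 2, dim=-1)    # "partition" across a group of 2 along dim -1
restored = torch.cat(shards, dim=-1)     # the matching "gather" undoes it exactly
assert torch.equal(full, restored)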
...

@@ -212,6 +392,73 @@ class Classifier3D(ParallelLayer):
            broadcast(self.bias, output_src_rank, self.output_parallel_mode)
            broadcast(self.bias, input_src_rank, self.input_parallel_mode)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            if self.has_weight:
                weight = state_dict.pop(weight_key, None)
                if weight is not None:
                    local_state[weight_key] = weight
            # bias
            if self.bias is not None:
                bias = state_dict.pop(bias_key, None)
                if bias is not None:
                    local_state[bias_key] = bias

        # partition in output groups
        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
                gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                self.output_parallel_mode,
                dims={weight_key: -1, bias_key: 0},
                partition_states={weight_key: True, bias_key: False},
            )
        # broadcast in input groups
        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
        # broadcast in weight groups
        local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)

        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        local_state = OrderedDict()
        if self.has_weight:
            local_state[weight_key] = self.weight
        if self.bias is not None:
            local_state[bias_key] = self.bias

        # gather in output groups
        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
                gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                self.output_parallel_mode,
                dims={weight_key: -1, bias_key: 0},
                partition_states={weight_key: True, bias_key: False},
                keep_vars=keep_vars,
            )
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
                             self.output_parallel_mode)
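Note the bias here travels with partition_states={bias_key: False}: it rides along in the state dict but is never sharded, because every rank holds a replica of the full classifier bias. A sketch of that flag's apparent meaning, inferred from the call sites in this file rather than from the helper's documented contract:

import torch

def partition_entry(tensor: torch.Tensor, dim: int, world_size: int, rank: int, partition: bool) -> torch.Tensor:
    if not partition:
        return tensor                                       # replicated: every rank keeps the full tensor
    return torch.chunk(tensor, world_size, dim=dim)[rank]   # sharded: each rank keeps one slice

bias = torch.randn(1000)
print(partition_entry(bias, 0, 4, 2, partition=False).shape)   # torch.Size([1000])
print(partition_entry(bias, 0, 4, 2, partition=True).shape)    # torch.Size([250])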
...

@@ -296,6 +543,122 @@ class VocabParallelClassifier3D(ParallelLayer):
            broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
            broadcast(self.bias, output_src_rank, self.output_parallel_mode)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            if self.has_weight:
                weight = state_dict.pop(weight_key, None)
                if weight is not None:
                    local_state[weight_key] = weight
            # bias
            if self.bias is not None:
                bias = state_dict.pop(bias_key, None)
                if bias is not None:
                    local_state[bias_key] = bias

        # partition in output groups
        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
                gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                self.output_parallel_mode,
                dims={weight_key: -1, bias_key: 0},
                partition_states={weight_key: True, bias_key: False},
            )
        # partition in input groups
        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                self.input_parallel_mode,
                dims={weight_key: 0, bias_key: 0},
                partition_states={weight_key: True, bias_key: True},
            )
        # partition in weight groups
        local_state = partition_tensor_parallel_state_dict(
            local_state,
            self.weight_parallel_mode,
            dims={weight_key: 0, bias_key: 0},
            partition_states={weight_key: True, bias_key: False},
        )

        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        local_state = OrderedDict({weight_key: self.weight})
        if self.bias is not None:
            local_state[bias_key] = self.bias

        # gather in weight groups
        local_state = gather_tensor_parallel_state_dict(
            local_state,
            self.weight_parallel_mode,
            dims={weight_key: 0, bias_key: 0},
            partition_states={weight_key: True, bias_key: False},
            keep_vars=keep_vars,
        )
        # gather in input groups
        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                self.input_parallel_mode,
                dims={weight_key: 0, bias_key: 0},
                partition_states={weight_key: True, bias_key: True},
                keep_vars=keep_vars,
            )
        # gather in output groups
        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
                gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                self.output_parallel_mode,
                dims={weight_key: -1, bias_key: 0},
                partition_states={weight_key: True, bias_key: False},
                keep_vars=keep_vars,
            )
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode,
                         self.weight_parallel_mode, self.output_parallel_mode)
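Unlike Linear3D, this class stores its weight as (num_classes, hidden) and transposes on the fly, because linear_3d — as it is used throughout this file — takes the weight in the (in, out) orientation. A plain-matmul check of that orientation with hypothetical sizes:

import torch

w = torch.randn(1000, 64)          # (num_classes, hidden), the stored layout
x = torch.randn(2, 64)             # (batch, hidden)
logits = x @ w.transpose(0, 1)     # same orientation linear_3d receives above
print(logits.shape)                # torch.Size([2, 1000])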
...

@@ -392,12 +755,98 @@ class PatchEmbedding3D(ParallelLayer):
        self.cls_token.register_hook(self._sync_grad_hook)
        self.pos_embed.register_hook(self._sync_grad_hook)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        cls_token_key = prefix + 'cls_token'
        pos_embed_key = prefix + 'pos_embed'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight
            # bias
            bias = state_dict.pop(bias_key, None)
            if bias is not None:
                local_state[bias_key] = bias
            # cls token
            cls_token = state_dict.pop(cls_token_key, None)
            if cls_token is not None:
                local_state[cls_token_key] = cls_token
            # pos embed
            pos_embed = state_dict.pop(pos_embed_key, None)
            if pos_embed is not None:
                local_state[pos_embed_key] = pos_embed

        # partition in output groups
        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
                gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                self.output_parallel_mode,
                dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
                partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
            )
        # broadcast in input groups
        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
        # broadcast in weight groups
        local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)

        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        cls_token_key = prefix + 'cls_token'
        pos_embed_key = prefix + 'pos_embed'
        local_state = OrderedDict({
            weight_key: self.weight,
            bias_key: self.bias,
            cls_token_key: self.cls_token,
            pos_embed_key: self.pos_embed
        })

        # gather in output groups
        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
                gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                self.output_parallel_mode,
                dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
                partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
                keep_vars=keep_vars,
            )
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
        input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
        output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
        if self.flatten:
            output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC

        cls_token = self.cls_token.expand(output.shape[0], -1, -1)
        output = torch.cat((cls_token, output), dim=1)
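The forward pass slices the global batch along dim 0 twice — once per group — before the conv2d patch projection, so each rank embeds only its share of the images. A local-batch sketch, assuming split_tensor_3d(input_, 0, mode) gives every rank of the group an equal slice of dim 0 (sizes are hypothetical):

import torch

x = torch.randn(16, 3, 224, 224)     # global image batch on a 2x2x2 cube
x = torch.chunk(x, 2, dim=0)[0]      # split across the weight group (this rank: 0)
x = torch.chunk(x, 2, dim=0)[0]      # split across the input group (this rank: 0)
print(x.shape)                       # torch.Size([4, 3, 224, 224])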
...

@@ -480,6 +929,49 @@ class Embedding3D(ParallelLayer):
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight

        # partition in output groups
        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
                gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                self.output_parallel_mode,
                dims={weight_key: 0},
                partition_states={weight_key: True},
            )
        # broadcast in input groups
        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
        # broadcast in weight groups
        local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)

        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        local_state = OrderedDict({weight_key: self.weight})

        # gather in output groups
        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
                gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                self.output_parallel_mode,
                dims={weight_key: 0},
                partition_states={weight_key: True},
                keep_vars=keep_vars,
            )
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
        input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
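Because the embedding table is only sharded across the output group, the load path broadcasts the dict through the input and weight groups instead of partitioning again. A toy, single-process simulation of what broadcast_state_dict() is assumed (from its use here) to do — copy rank 0's entries to every rank of the group:

def broadcast_state_dict_sim(per_rank_states):
    # rank 0 is the source; afterwards every rank holds an identical copy
    return [dict(per_rank_states[0]) for _ in per_rank_states]

states = [{'weight': [1.0, 2.0, 3.0]}, {}, {}, {}]   # only rank 0 holds the entries
print(broadcast_state_dict_sim(states))              # all four ranks now match rank 0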
...

@@ -570,6 +1062,76 @@ class VocabParallelEmbedding3D(torch.nn.Module):
            with torch.no_grad():
                self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight

        # partition in output groups
        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
                gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                self.output_parallel_mode,
                dims={weight_key: -1},
                partition_states={weight_key: True},
            )
        # partition in input groups
        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = partition_tensor_parallel_state_dict(
                local_state,
                self.input_parallel_mode,
                dims={weight_key: 0},
                partition_states={weight_key: True},
            )
        # partition in weight groups
        local_state = partition_tensor_parallel_state_dict(
            local_state,
            self.weight_parallel_mode,
            dims={weight_key: 0},
            partition_states={weight_key: True},
        )

        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        local_state = OrderedDict({weight_key: self.weight})

        # gather in weight groups
        local_state = gather_tensor_parallel_state_dict(
            local_state,
            self.weight_parallel_mode,
            dims={weight_key: 0},
            partition_states={weight_key: True},
            keep_vars=keep_vars,
        )
        # gather in input groups
        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                self.input_parallel_mode,
                dims={weight_key: 0},
                partition_states={weight_key: True},
                keep_vars=keep_vars,
            )
        # gather in output groups
        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
                gpc.get_local_rank(self.weight_parallel_mode) == 0:
            local_state = gather_tensor_parallel_state_dict(
                local_state,
                self.output_parallel_mode,
                dims={weight_key: -1},
                partition_states={weight_key: True},
                keep_vars=keep_vars,
            )
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)

...
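Since the vocabulary dimension is sharded across ranks, the padding reset above indexes the row as self.padding_idx - self.vocab_start_index: the global padding row must be mapped into this rank's local slice (and, presumably, only the rank whose slice contains it performs the fill — that guard sits outside the shown hunk). A small arithmetic sketch with hypothetical sizes:

num_embeddings, world_size, rank = 1000, 4, 1
per_partition = num_embeddings // world_size          # 250 vocabulary rows per rank
vocab_start_index = rank * per_partition              # 250
padding_idx = 300                                     # global row index
local_row = padding_idx - vocab_start_index           # 50: the row this rank zeroes
print(local_row)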