GitLab · OpenDAS / ColossalAI · Commits · 015af592

Commit 015af592, authored Jun 15, 2023 by Frank Lee

[shardformer] integrated linear 1D with dtensor (#3996)

* [shardformer] integrated linear 1D with dtensor
* polish code

Parent: d3bc5308

Showing 9 changed files with 706 additions and 407 deletions (+706, -407):
- colossalai/nn/layer/base_layer.py (+1, -0)
- colossalai/shardformer/layer/_operation.py (+126, -7)
- colossalai/shardformer/layer/dropout.py (+8, -46)
- colossalai/shardformer/layer/layers.py (+310, -343)
- colossalai/shardformer/layer/utils.py (+138, -0)
- colossalai/tensor/d_tensor/api.py (+44, -0)
- colossalai/tensor/d_tensor/layout.py (+11, -10)
- colossalai/tensor/d_tensor/layout_converter.py (+1, -1)
- tests/test_shardformer/test_layer/test_linear_1d.py (+67, -0)
colossalai/nn/layer/base_layer.py

@@ -10,6 +10,7 @@ from colossalai.core import global_context as gpc

`ParallelLayer` gains a class-level flag:

```python
class ParallelLayer(nn.Module):

    global_state_dict: bool = True

    def __init__(self):
        ...
```
colossalai/shardformer/layer/_operation.py

@@ -54,10 +54,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):

`LinearWithAsyncCommunication.forward` now takes an explicit `process_group` instead of a `parallel_mode`:

```python
    @staticmethod
    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None
        ctx.process_group = process_group
        ctx.async_grad_allreduce = async_grad_allreduce

        output = torch.matmul(input_, weight.t())
```

@@ -74,12 +74,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):

In `backward`, the 2D reshape is now guarded so that only tensors with more than two dimensions are flattened (with `view(-1, shape[-1])` instead of hard-coded shape indices), and the asynchronous all-reduce uses the stored process group rather than `gpc.get_group(ctx.parallel_mode)`:

```python
        grad_input = grad_output.matmul(weight)
        grad_output = grad_output.contiguous()

        # Convert the tensor shapes to 2D for execution compatibility
        if len(grad_output.shape) > 2:
            grad_output = grad_output.view(-1, grad_output.shape[-1])
            total_input = total_input.view(-1, total_input.shape[-1])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
            handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # all-reduce scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1
```

@@ -93,5 +94,123 @@ class LinearWithAsyncCommunication(torch.autograd.Function):

The old module-level `linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce)` helper is deleted and re-added at the bottom of the file with a `process_group` parameter. In between, four new communication primitives are introduced (stale `parallel_mode` mentions in their docstrings are corrected here to the actual `process_group` arguments):

```python
        return grad_input, grad_weight, grad_bias, None, None, None


class _SplitForwardGatherBackward(torch.autograd.Function):
    """
    Split the input and keep only the chunk corresponding to the rank.

    Args:
        input_ (`torch.Tensor`): input matrix.
        dim (int): the dimension to perform split and gather.
        process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication.
    """

    @staticmethod
    def forward(ctx, input_, dim, process_group):
        ctx.process_group = process_group
        ctx.dim = dim
        return _split(input_, dim, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output, ctx.dim, ctx.process_group), None, None


class _ReduceInput(torch.autograd.Function):
    """
    All-reduce the input from the model parallel region.

    Args:
        input_ (`torch.Tensor`): input matrix.
        process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication.
    """

    @staticmethod
    def forward(ctx, input_, process_group):
        return _reduce(input_, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


def _reduce(input_, process_group):
    # skip if only one rank involved
    if dist.get_world_size(process_group) == 1:
        return input_
    else:
        dist.all_reduce(input_, group=process_group)
        return input_


def _split(input_, dim=-1, process_group=None):
    # skip if only one rank involved
    world_size = dist.get_world_size(process_group)
    if world_size == 1:
        return input_

    # Split along last dimension.
    dim_size = input_.size(dim)
    assert dim_size % world_size == 0, \
        f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
        f'cannot split tensor evenly'

    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
    rank = dist.get_rank(process_group)
    output = tensor_list[rank].contiguous()
    return output


def _gather(input_, dim=-1, process_group=None):
    # skip if only one rank involved
    world_size = dist.get_world_size(process_group)
    if world_size == 1:
        return input_

    # all gather
    rank = dist.get_rank(process_group)
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=process_group)

    # concat
    output = torch.cat(tensor_list, dim=dim).contiguous()
    return output


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from model parallel region and concatenate.

    Args:
        input_ (`torch.Tensor`): input matrix.
        dim (int): the dimension to gather along.
        process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication.
    """

    @staticmethod
    def forward(ctx, input_, dim, process_group):
        ctx.process_group = process_group
        ctx.dim = dim
        return _gather(input_, dim, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output, ctx.dim, ctx.process_group), None, None


def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
    return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)


def gather_forward_split_backward(input_, dim, process_group):
    return _GatherForwardSplitBackward.apply(input_, dim, process_group)


def split_forward_gather_backward(input_, dim, process_group):
    return _SplitForwardGatherBackward.apply(input_, dim, process_group)


def reduce_input(input_, process_group):
    return _ReduceInput.apply(input_, process_group)
```
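Taken together, `_split`/`_gather` and their autograd wrappers form conjugate pairs: whatever collective runs in the forward pass, its inverse runs on the gradient. Below is a minimal standalone sketch of that behavior, not part of the commit; it assumes the functions above are importable from `colossalai.shardformer.layer._operation` and uses the CPU `gloo` backend so it runs without GPUs:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    from colossalai.shardformer.layer._operation import gather_forward_split_backward

    # each rank owns a distinct 2x2 shard of a logically 2x4 tensor
    local = torch.full((2, 2), float(rank), requires_grad=True)

    # forward: all-gather + concat along the last dim -> full 2x4 tensor
    gathered = gather_forward_split_backward(local, dim=-1, process_group=None)
    assert gathered.shape == (2, 2 * world_size)

    # backward: the conjugate op splits the gradient, so each rank only
    # receives the slice corresponding to its own shard
    gathered.sum().backward()
    assert local.grad.shape == local.shape

    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)
```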
colossalai/shardformer/layer/dropout.py

The ad-hoc `SeedManager` (which seeded the CUDA RNG from `os.getpid()` and swapped RNG states through `set_mode`/`get_current_mode` and a `dropout_mode()` context manager) is removed, together with the module-level `_seed_manager` singleton and the now-unused `os`/`contextmanager` imports. `Dropout1D` instead accepts a `process_group` and draws its masks from a per-layer `Randomizer` created by the new `create_randomizer_with_offset` utility (see `colossalai/shardformer/layer/utils.py` below):

```python
import torch
import torch.nn as nn

from .utils import create_randomizer_with_offset


class Dropout1D(nn.Dropout):

    def __init__(self, p=0.5, inplace=False, process_group=None):
        super().__init__(p, inplace)

        # offset the seed with randomizer index and rank
        seed = torch.random.initial_seed()
        self.randomizer = create_randomizer_with_offset(seed, process_group=process_group)

    def forward(self, input):
        with self.randomizer.fork_rng():
            input = super().forward(input)

        return input
```
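A hedged usage sketch (my illustration, not from the commit; the import path is an assumption): because the randomizer offsets the base seed by both the layer index and the process rank, each rank draws an independent mask, while the layer's private RNG stream stays separate from the surrounding program's RNG state. A CUDA device is required, since `Randomizer` snapshots the CUDA RNG:

```python
import torch

from colossalai.shardformer.layer.dropout import Dropout1D

torch.manual_seed(42)
dropout = Dropout1D(p=0.5, process_group=None)    # rank offset is 0 if dist is uninitialized

x = torch.ones(4, 8, device='cuda')
y1 = dropout(x)    # mask drawn from the layer's private RNG stream
y2 = dropout(x)    # the private stream advances, so a fresh mask is drawn

# the global RNG state was forked, not consumed: unrelated random ops
# outside fork_rng() are unaffected by the masks drawn above
```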
colossalai/shardformer/layer/layers.py

@@ -2,12 +2,16 @@

The import block gains `abc`, `typing.List`/`Union`, `torch.distributed`, and `ProcessGroup`:

```python
# -*- encoding: utf-8 -*-

import math
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Callable, List, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

from colossalai.communication import broadcast
```

@@ -22,13 +26,11 @@ from colossalai.nn.layer.parallel_1d._utils import (

`reduce_input` and `split_forward_gather_backward` are no longer taken from `colossalai.nn.layer.parallel_1d._utils`, and the `LAYERS` registry import is replaced by the dtensor sharding API:

```python
    gather_forward_split_backward,
    get_parallel_input,
    reduce_grad,
    set_parallel_input,
)
from colossalai.nn.layer.utils import divide, set_tensor_parallel_attribute_by_partition
from colossalai.nn.layer.vanilla import VanillaLayerNorm, VanillaPatchEmbedding
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
from colossalai.utils.checkpointing import (
    broadcast_state_dict,
    gather_tensor_parallel_state_dict,
```

@@ -36,7 +38,13 @@ from colossalai.utils.checkpointing import (

The communication primitives now come from the reworked `._operation` module, plus the new randomizer utility:

```python
)
from colossalai.utils.cuda import get_current_device

from ._operation import (
    gather_forward_split_backward,
    linear_with_async_comm,
    reduce_input,
    split_forward_gather_backward,
)
from .utils import create_randomizer_with_offset

Fast_LN = None
try:
```

@@ -46,21 +54,44 @@ except ImportError:

```python
    pass
```
The `Linear1D(ColossalaiModule)` dispatch wrapper (whose constructor inspected `get_parallel_input()` and instantiated either a column- or a row-parallel layer before calling `super().__init__(layer)`) is deleted; its removed lines appear interleaved through this hunk and the next. In its place, an abstract `ParallelModule` base class defines the conversion interface, and `Linear1D_Col` is reimplemented on top of it with an explicit `process_group`, dtensor-based weight sharding, and per-layer randomized initialization:

```python
class ParallelModule(nn.Module, ABC):

    @abstractmethod
    def from_native_module(module: nn.Module,
                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule":
        """
        Convert a native PyTorch module to a parallelized module.

        Args:
            module (nn.Module): the module to be converted.
            process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
                If this is a list, the process group at the ith index of the list will correspond to the process group
                in the ith axis of the device mesh. Defaults to None, which means the global process group.
        """
        pass


class Linear1D_Col(ParallelModule):
    r"""Linear layer with column parallelism.

    The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
    its second dimension as :math:`A = [A_1, ..., A_p]`.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
        device (`torch.device`): The device of parameters, defaults to None.
        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
        gather_output (bool, optional): If true, call all-gather on output and make Y available
            to all GPUs, otherwise, every GPU will have its output
            which is :math:`Y_i = XA_i`, defaults to False
        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
        weight_initializer (`typing.Callable`):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (`typing.Callable`):
            The initializer of bias, defaults to xavier uniform initializer.

    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """
```

@@ -72,32 +103,281 @@ class Linear1D(ColossalaiModule):

```python
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 process_group: ProcessGroup = None,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
        super().__init__()

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add
        self.device = device
        self.process_group = process_group
        self.num_partitions = dist.get_world_size(self.process_group)

        if skip_bias_add and not bias:
            raise ValueError('cannot skip bias addition if bias is None')

        self.out_features_per_partition = divide(out_features, self.num_partitions)

        # Parameters.
        # Initialize weight.
        if device is None:
            device = get_current_device()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))

        if bias:
            self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
        else:
            self.bias = None

        # offset the seed with randomizer index and rank
        seed = torch.random.initial_seed()
        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)

        with self.randomizer.fork_rng(enable_cpu=True):
            self.reset_parameters(weight_initializer, bias_initializer)

    @staticmethod
    def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
                           **kwargs) -> ParallelModule:
        r"""
        Convert a native PyTorch linear layer to a parallelized linear layer.
        """
        # get the attributes
        in_features = module.in_features
        out_features = module.out_features
        bias = module.bias is not None
        device = module.weight.device

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, \
                f'Expected only one process group, got {len(process_group)}.'
            process_group = process_group[0]

        linear_1d = Linear1D_Col(in_features=in_features,
                                 out_features=out_features,
                                 bias=bias,
                                 device=device,
                                 process_group=process_group,
                                 *args,
                                 **kwargs)

        # TODO: copy the sharded weights
        with torch.no_grad():
            # the weight to the linear layer is a transpose
            # thus shard on row is equal to shard on column
            sharded_weight = shard_rowwise(module.weight.data, process_group)
            linear_1d.weight.data.copy_(sharded_weight)
            if bias:
                sharded_bias = shard_colwise(module.bias.data, process_group)
                linear_1d.bias.copy_(sharded_bias)

        return linear_1d

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        fan_in, fan_out = self.in_features, self.out_features
        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
        if self.bias is not None:
            bias_initializer(self.bias, fan_in=fan_in)

    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
        assert input_.shape[-1] == self.weight.shape[-1], \
            'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
                input_.shape, self.weight.shape, self.weight.shape[-1])

        # Set up backprop all-reduce.
        # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
        input_parallel = input_

        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)

        if self.gather_output:
            # All-gather across the partitions.
            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
        else:
            output = output_parallel

        if self.skip_bias_add:
            return output, self.bias
        else:
            return output
```
`Linear1D_Row` is rewritten the same way (the remaining deleted fragments of the old `Linear1D` wrapper, `layer = Linear1D_Row(...)` and `super().__init__(layer)`, fall inside this hunk):

```python
class Linear1D_Row(ParallelModule):
    r""" Linear layer with row parallelism

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
        parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (:class:`typing.Callable`, optional):
            The initializer of bias, defaults to xavier uniform initializer.

    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 process_group: ProcessGroup = None,
                 parallel_input: bool = True,
                 skip_bias_add: bool = False,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                 stream_chunk_num: int = 1):
        super().__init__()

        self.stream_chunk_num = stream_chunk_num

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.parallel_input = parallel_input
        self.skip_bias_add = skip_bias_add
        self.process_group = process_group
        self.num_partitions = dist.get_world_size(self.process_group)

        if skip_bias_add and not bias:
            raise ValueError('cannot skip bias addition if bias is None')

        # Divide the weight matrix along the last dimension.
        self.input_size_per_partition = divide(in_features, self.num_partitions)

        # Parameters.
        # Initialize weight.
        if device is None:
            device = get_current_device()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))

        if self.stream_chunk_num > 1:
            # TODO() work for inference only
            self.chunk_weight()
        if bias:
            self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
        else:
            self.bias = None

        # offset the seed with randomizer index and rank
        seed = torch.random.initial_seed()
        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)

        with self.randomizer.fork_rng(enable_cpu=True):
            self.reset_parameters(weight_initializer, bias_initializer)

    @staticmethod
    def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
                           **kwargs) -> ParallelModule:
        r"""
        Convert a native PyTorch linear layer to a parallelized linear layer.
        """
        # get the attributes
        in_features = module.in_features
        out_features = module.out_features
        bias = module.bias is not None
        device = module.weight.device

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, \
                f'Expected only one process group, got {len(process_group)}.'
            process_group = process_group[0]

        linear_1d = Linear1D_Row(in_features=in_features,
                                 out_features=out_features,
                                 bias=bias,
                                 device=device,
                                 process_group=process_group,
                                 *args,
                                 **kwargs)

        # TODO: copy the sharded weights
        with torch.no_grad():
            # the weight to the linear layer is a transpose
            # thus shard on col is equal to shard on row
            sharded_weight = shard_colwise(module.weight.data, process_group)
            linear_1d.weight.data.copy_(sharded_weight)

            if bias:
                linear_1d.bias.copy_(module.bias.data)

        return linear_1d

    def chunk_weight(self):
        self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        fan_in, fan_out = self.in_features, self.out_features
        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)

        if self.bias is not None:
            bias_initializer(self.bias, fan_in=fan_in)
            if self.process_group is None:
                src_rank = 0
            else:
                src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
            dist.broadcast(self.bias, src=src_rank, group=self.process_group)

    def forward(self, input_: Tensor) -> Tensor:
        # Set up backprop all-reduce.
        if self.parallel_input:
            assert input_.shape[-1] == self.weight.shape[-1], \
                'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
                    input_.shape, self.weight.shape, self.weight.shape[-1])
            input_ = input_
        else:
            assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \
                'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
                    input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions)
            input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)

        if self.stream_chunk_num > 1:
            if self.training:
                raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
            with torch.no_grad():
                output_parallel_list = [None for i in range(self.stream_chunk_num)]
                handle_list = []
                for i in range(self.stream_chunk_num):
                    output_parallel_list[i] = F.linear(input_, self.weight_list[i])
                    handle = torch.distributed.all_reduce(output_parallel_list[i],
                                                          group=self.process_group,
                                                          async_op=True)
                    handle_list.append(handle)
                    # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
                for handle in handle_list:
                    handle.wait()
                output = torch.cat(output_parallel_list, dim=-1)
        else:
            output_parallel = F.linear(input_, self.weight)
            # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
            output = reduce_input(output_parallel, self.process_group)

        if not self.skip_bias_add:
            if self.bias is not None:
                output = output + self.bias
            return output
        else:
            return output, self.bias
```
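For orientation, a usage sketch of the new `from_native_module` path (it mirrors the test added in this commit and assumes a 2-process NCCL group is already initialized, e.g. via `colossalai.launch`):

```python
import torch.nn as nn

from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row

linear = nn.Linear(32, 128).cuda()

# column parallelism: the (128, 32) weight is shard_rowwise'd into (64, 32)
# per rank; gather_output=True all-gathers Y so every rank sees (N, 128)
col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)

# row parallelism: the (128, 32) weight is shard_colwise'd into (128, 16)
# per rank; parallel_input=False makes the layer split X internally
row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
```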
The commented-out `# @LAYERS.register_module` markers above `LayerNorm1D`, `Classifier1D`, and `VocabParallelClassifier1D` are removed; the classes themselves are unchanged:

@@ -152,7 +432,6 @@ class LayerNorm1D(ColossalaiModule):
@@ -288,7 +567,6 @@ class Classifier1D(ParallelLayer):
@@ -424,317 +702,8 @@ class VocabParallelClassifier1D(ParallelLayer):

Removed: the legacy `ParallelLayer`-based `Linear1D_Col` and `Linear1D_Row`, which depended on the global context (`gpc`), `ParallelMode.PARALLEL_1D`, and the global-state-dict gather/partition helpers:

```python
# @LAYERS.register_module
class Linear1D_Col(ParallelLayer):
    r"""Linear layer with column parallelism.

    The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
    its second dimension as :math:`A = [A_1, ..., A_p]`.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
        gather_output (bool, optional): If true, call all-gather on output and make Y available
            to all GPUs, otherwise, every GPU will have its output
            which is :math:`Y_i = XA_i`, defaults to False
        skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (:class:`typing.Callable`, optional):
            The initializer of bias, defaults to xavier uniform initializer.

    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
        super().__init__()

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        if skip_bias_add and not bias:
            raise ValueError('cannot skip bias addition if bias is None')

        # self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size)
        self.out_features_per_partition = out_features

        # Parameters.
        # Initialize weight.
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))

        if bias:
            self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
        else:
            self.bias = None
        with seed(ParallelMode.TENSOR):
            self.reset_parameters(weight_initializer, bias_initializer)
        self._set_tensor_parallel_attributes()
        is_parallel_output = not self.gather_output
        set_parallel_input(is_parallel_output)

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        fan_in, fan_out = self.in_features, self.out_features
        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
        if self.bias is not None:
            bias_initializer(self.bias, fan_in=fan_in)

    def _set_tensor_parallel_attributes(self):
        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
        set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
        if self.bias is not None:
            set_tensor_parallel_attribute_by_partition(self.bias, num_partition)

    def _load_from_global_state_dict(self, state_dict, prefix, *args):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight
            # bias
            if self.bias is not None:
                bias = state_dict.pop(bias_key, None)
                if bias is not None:
                    local_state[bias_key] = bias

        local_state = partition_tensor_parallel_state_dict(local_state,
                                                           ParallelMode.PARALLEL_1D,
                                                           dims={
                                                               weight_key: 0,
                                                               bias_key: 0
                                                           },
                                                           partition_states={
                                                               weight_key: True,
                                                               bias_key: True
                                                           })
        super()._load_from_global_state_dict(local_state, prefix, *args)

    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        local_state = OrderedDict({weight_key: self.weight})
        if self.bias is not None:
            local_state[bias_key] = self.bias
        local_state = gather_tensor_parallel_state_dict(local_state,
                                                        ParallelMode.PARALLEL_1D,
                                                        dims={
                                                            weight_key: 0,
                                                            bias_key: 0
                                                        },
                                                        partition_states={
                                                            weight_key: True,
                                                            bias_key: True
                                                        },
                                                        keep_vars=keep_vars)
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
        assert input_.shape[-1] == self.weight.shape[-1], \
            'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
                input_.shape, self.weight.shape, self.weight.shape[-1])
        # Set up backprop all-reduce.
        # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
        input_parallel = input_
        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        # output_parallel = F.linear(input_parallel, self.weight, bias)
        output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True)
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
        else:
            output = output_parallel

        if self.skip_bias_add:
            return output, self.bias
        else:
            return output


# @LAYERS.register_module
class Linear1D_Row(ParallelLayer):
    r""" Linear layer with row parallelism

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
        parallel_input (bool, optional): If set to ``True``, it's assumed that the input is split, defaults to False.
        skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (:class:`typing.Callable`, optional):
            The initializer of bias, defaults to xavier uniform initializer.

    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 parallel_input: bool = True,
                 skip_bias_add: bool = False,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                 stream_chunk_num: int = 1):
        super().__init__()

        self.stream_chunk_num = stream_chunk_num

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.parallel_input = parallel_input
        self.skip_bias_add = skip_bias_add

        if skip_bias_add and not bias:
            raise ValueError('cannot skip bias addition if bias is None')

        # Divide the weight matrix along the last dimension.
        # self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size)
        self.input_size_per_partition = in_features

        # Parameters.
        # Initialize weight.
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))

        if self.stream_chunk_num > 1:
            # TODO() work for inference only
            self.chunk_weight()
        if bias:
            self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
        else:
            self.bias = None
        with seed(ParallelMode.TENSOR):
            self.reset_parameters(weight_initializer, bias_initializer)
        self._set_tensor_parallel_attributes()
        set_parallel_input(False)

    def chunk_weight(self):
        self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        fan_in, fan_out = self.in_features, self.out_features
        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
        if self.bias is not None:
            bias_initializer(self.bias, fan_in=fan_in)
            broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)

    def _set_tensor_parallel_attributes(self):
        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
        set_tensor_parallel_attribute_by_partition(self.weight, num_partition)

    def _load_from_global_state_dict(self, state_dict, prefix, *args):
        local_state = OrderedDict()
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
            # weight
            weight = state_dict.pop(weight_key, None)
            if weight is not None:
                local_state[weight_key] = weight
            # bias
            if self.bias is not None:
                bias = state_dict.pop(bias_key, None)
                if bias is not None:
                    local_state[bias_key] = bias

        local_state = partition_tensor_parallel_state_dict(local_state,
                                                           ParallelMode.PARALLEL_1D,
                                                           dims={
                                                               weight_key: -1,
                                                               bias_key: 0
                                                           },
                                                           partition_states={
                                                               weight_key: True,
                                                               bias_key: False
                                                           })
        super()._load_from_global_state_dict(local_state, prefix, *args)

    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        local_state = OrderedDict({weight_key: self.weight})
        if self.bias is not None:
            local_state[bias_key] = self.bias
        local_state = gather_tensor_parallel_state_dict(local_state,
                                                        ParallelMode.PARALLEL_1D,
                                                        dims={
                                                            weight_key: -1,
                                                            bias_key: 0
                                                        },
                                                        partition_states={
                                                            weight_key: True,
                                                            bias_key: False
                                                        },
                                                        keep_vars=keep_vars)
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        # Set up backprop all-reduce.
        if self.parallel_input:
            assert input_.shape[-1] == self.weight.shape[-1], \
                'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
                    input_.shape, self.weight.shape, self.weight.shape[-1])
            input_ = input_
        else:
            assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
                'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
                    input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
            input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)

        if self.stream_chunk_num > 1:
            if self.training:
                raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
            with torch.no_grad():
                output_parallel_list = [None for i in range(self.stream_chunk_num)]
                handle_list = []
                for i in range(self.stream_chunk_num):
                    output_parallel_list[i] = F.linear(input_, self.weight_list[i])
                    handle = torch.distributed.all_reduce(output_parallel_list[i],
                                                          group=gpc.get_group(ParallelMode.PARALLEL_1D),
                                                          async_op=True)
                    handle_list.append(handle)
                    # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
                for handle in handle_list:
                    handle.wait()
                output = torch.cat(output_parallel_list, dim=-1)
        else:
            output_parallel = F.linear(input_, self.weight)
            # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
            output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)

        if not self.skip_bias_add:
            if self.bias is not None:
                output = output + self.bias
            return output
        else:
            return output, self.bias
```
Likewise, the `# @LAYERS.register_module` markers above `Embedding1D`, `VocabParallelEmbedding1D`, and `Dropout1D` are removed without touching the classes themselves:

@@ -842,7 +811,6 @@ class Embedding1D(ParallelLayer):
@@ -960,7 +928,6 @@ class VocabParallelEmbedding1D(ParallelLayer):
colossalai/shardformer/layer/utils.py (new file, mode 100644)

New `Randomizer` utility. The `fork_rng` docstring as committed still referenced the removed `_seed_manager.dropout_mode()` API and the docstrings listed parameters that do not exist on the functions they document; both are corrected below, and a typo in a comment ("to the same" for "do the same") is fixed:

```python
from contextlib import contextmanager

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


class Randomizer:
    """
    Randomizer enables the program to be executed under a different seed within the context.

    Example:

    ```python
    randomizer = Randomizer(seed=1024)

    with randomizer.fork():
        # do something here with seed 1024
        do_something()
    ```

    Args:
        seed (int): The random seed to set.
    """

    _INDEX = 0

    def __init__(self, seed: int):
        # TODO: remove colossalai.context.random

        self.seed = seed

        # Handle CUDA rng state
        # 1. get the current rng state
        # 2. set the seed and store the rng state
        # 3. recover the original rng state
        cuda_original_rng_state = torch.cuda.get_rng_state()
        torch.cuda.manual_seed(seed)
        self.cuda_rng_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(cuda_original_rng_state)

        # do the same for cpu rng state
        cpu_original_rng_state = torch.get_rng_state()
        torch.manual_seed(seed)
        self.cpu_rng_state = torch.get_rng_state()
        torch.set_rng_state(cpu_original_rng_state)

    def _set_cuda_rng_state(self, rng_state):
        torch.cuda.set_rng_state(rng_state)

    def _get_cuda_rng_state(self):
        current_state = torch.cuda.get_rng_state()
        return current_state

    def _set_cpu_rng_state(self, rng_state):
        torch.set_rng_state(rng_state)

    def _get_cpu_rng_state(self):
        current_state = torch.get_rng_state()
        return current_state

    @contextmanager
    def fork_rng(self, enable_cpu: bool = False):
        """
        This is a context manager to switch to the randomizer's RNG state and
        recover the original state once the block exits. If ``enable_cpu`` is
        True, the CPU RNG state is forked as well as the CUDA one.

        Usage::

            >>> with randomizer.fork_rng():
            >>>     input = super().forward(input)
        """
        try:
            current_cuda_rng_state = self._get_cuda_rng_state()
            self._set_cuda_rng_state(self.cuda_rng_state)

            if enable_cpu:
                current_cpu_rng_state = self._get_cpu_rng_state()
                self._set_cpu_rng_state(self.cpu_rng_state)
            yield
        finally:
            self.cuda_rng_state = self._get_cuda_rng_state()
            self._set_cuda_rng_state(current_cuda_rng_state)

            if enable_cpu:
                self.cpu_rng_state = self._get_cpu_rng_state()
                self._set_cpu_rng_state(current_cpu_rng_state)

    @staticmethod
    def index():
        """
        Return the index of the randomizer. The index is useful when the user wants
        to introduce some randomness in the program.

        Note:
            The index will increment by one each time this method is called.

        Example:

        ```python
        # assume we need a randomizer to init the weight of different layers
        # we can use the index of the randomizer so that
        # each layer has its own randomizer with a different seed
        base_seed = torch.random.initial_seed()
        seed = base_seed + Randomizer.index()
        randomizer = Randomizer(seed)

        with randomizer.fork():
            init_weights()
        ```
        """
        idx = Randomizer._INDEX
        Randomizer._INDEX += 1
        return idx


def create_randomizer_with_offset(seed: int, process_group: ProcessGroup = None):
    """
    Create a randomizer with an offset. The offset is equal to the rank of the process plus the index of the randomizer.

    Args:
        seed (int): The base random seed to set.
        process_group (ProcessGroup): the process group to get the rank from.

    Returns:
        Randomizer: the randomizer with offset.
    """
    offset = Randomizer.index()

    if dist.is_initialized():
        rank = dist.get_rank(process_group)
        offset += rank

    seed += offset
    return Randomizer(seed=seed)
```
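The effective seed is `base_seed + Randomizer.index() + rank`, so every (layer, rank) pair gets a distinct RNG stream. A small sketch of the arithmetic (my example, not from the commit; it needs a CUDA device because `Randomizer.__init__` snapshots the CUDA RNG state):

```python
from colossalai.shardformer.layer.utils import create_randomizer_with_offset

base = 1234
r0 = create_randomizer_with_offset(base)    # consumes index 0 (+ rank, if dist is initialized)
r1 = create_randomizer_with_offset(base)    # consumes index 1

# consecutive layers built from the same base seed still diverge by one,
# whether or not the common rank offset is applied
assert r1.seed == r0.seed + 1
```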
colossalai/tensor/d_tensor/api.py (new file, mode 100644)

New dtensor sharding helpers. As committed, `shard_colwise` carried two copy-paste slips from `shard_rowwise` (a docstring saying "first dim" and an assert message saying "row-wise"); both are corrected below:

```python
from typing import Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.device.device_mesh import DeviceMesh

from .d_tensor import DTensor
from .sharding_spec import ShardingSpec


def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
    """
    Shard the first dim of the given tensor.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD

    if isinstance(group_or_device_mesh, ProcessGroup):
        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
    else:
        assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
        device_mesh = group_or_device_mesh
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})

    return DTensor(tensor, device_mesh, sharding_spec)


def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
    """
    Shard the last dim of the given tensor.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD

    if isinstance(group_or_device_mesh, ProcessGroup):
        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
    else:
        assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for column-wise sharding.'
        device_mesh = group_or_device_mesh
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})

    return DTensor(tensor, device_mesh, sharding_spec)
```
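A usage sketch (my illustration; it assumes a 2-process group is initialized): `shard_rowwise` partitions dim 0 while `shard_colwise` partitions dim -1, which is exactly how `Linear1D_Col` and `Linear1D_Row` pick their axes for the transposed `nn.Linear` weight:

```python
import torch

from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise

weight = torch.randn(128, 32)

# row-wise: dim 0 is split across the group -> each rank holds (64, 32),
# matching the per-rank weight of Linear1D_Col above
d_col_weight = shard_rowwise(weight)

# col-wise: dim -1 is split -> each rank holds (128, 16),
# matching the per-rank weight of Linear1D_Row
d_row_weight = shard_colwise(weight)
```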
colossalai/tensor/d_tensor/layout.py

@@ -34,7 +34,7 @@ class Layout:

In `get_sharded_shape_per_device`, the mesh extent lookup changes from `self.device_mesh.mesh_shape[mesh_dim]` to `self.device_mesh.shape[mesh_dim]`:

```python
    def get_sharded_shape_per_device(self):
        sharded_shape = list(self.entire_shape)
        for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
            mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
            shard_partitions = reduce(operator.mul, mesh_list, 1)
            assert sharded_shape[dim] % shard_partitions == 0, \
                f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
```

@@ -45,14 +45,15 @@ class Layout:

The duplicate-axis check now only runs when a logical mesh id exists:

```python
        sharding_spec = self.sharding_spec

        # make sure all axes in logical device mesh only be used once
        if self.device_mesh.logical_mesh_id is not None:
            dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
            for dim, shard_list in sharding_spec.dim_partition_dict.items():
                for element in shard_list:
                    if element in dim_check_list:
                        dim_check_list.remove(element)
                    else:
                        raise DuplicatedShardingDimensionError(
                            f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")

        # make sure that the sharding for a dimension is divisible by the number of devices
        for dim, shard_list in sharding_spec.dim_partition_dict.items():
```

@@ -60,7 +61,7 @@ class Layout:

The divisibility check gets the same `mesh_shape` to `shape` rename:

```python
            num_devices = 1
            for element in shard_list:
                num_devices *= self.device_mesh.shape[element]
            if tensor_dim_size % num_devices != 0:
                raise ShardingNotDivisibleError(
```
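To make the per-device shape arithmetic concrete, here is a standalone re-derivation of `get_sharded_shape_per_device` with made-up numbers (a sketch only, not using the `Layout` class itself):

```python
import operator
from functools import reduce

entire_shape = [8, 6]           # logical tensor shape
mesh_shape = (2,)               # a 1-D device mesh with 2 devices
dim_partition_dict = {0: [0]}   # tensor dim 0 is sharded over mesh axis 0

sharded_shape = list(entire_shape)
for dim, shard_list in dim_partition_dict.items():
    # total number of partitions applied to this tensor dimension
    shard_partitions = reduce(operator.mul, (mesh_shape[m] for m in shard_list), 1)
    assert sharded_shape[dim] % shard_partitions == 0, 'dimension must divide evenly'
    sharded_shape[dim] //= shard_partitions

print(sharded_shape)    # [4, 6]: each device holds a (4, 6) shard
```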
colossalai/tensor/d_tensor/layout_converter.py

@@ -304,7 +304,7 @@ class LayoutConverter(metaclass=SingletonMeta):

The same `mesh_shape` to `shape` rename:

```python
        process_groups_dict = source_layout.device_mesh.process_groups_dict

        # legal sharding dims means the mesh_id is still available to use.
        legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))]
        for dim, shard_list in source_spec.dim_partition_dict.items():
            for element in shard_list:
                legal_sharding_dims.remove(element)
```
tests/test_shardformer/test_layer/test_linear_1d.py (new file, mode 100644)

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


def check_linear_1d_col():
    linear = nn.Linear(32, 128).cuda()
    linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)

    assert linear_col.weight.shape == torch.Size([64, 32])
    assert linear_col.bias.shape == torch.Size([64])

    # check computation correctness
    x = torch.rand(4, 32).cuda()
    out = linear(x)
    gather_out = linear_col(x)
    assert_close(out, gather_out)

    # check backward correctness
    out.sum().backward()
    gather_out.sum().backward()

    rank = dist.get_rank()
    target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
    assert_close(target_grad, linear_col.weight.grad)


def check_linear_1d_row():
    linear = nn.Linear(32, 128).cuda()
    linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)

    assert linear_row.weight.shape == torch.Size([128, 16])
    assert linear_row.bias.shape == torch.Size([128])

    # check computation correctness
    x = torch.rand(4, 32).cuda()
    out = linear(x)
    gather_out = linear_row(x)
    assert_close(out, gather_out)

    # check backward correctness
    out.sum().backward()
    gather_out.sum().backward()

    rank = dist.get_rank()
    target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
    assert_close(target_grad, linear_row.weight.grad)


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    check_linear_1d_col()
    check_linear_1d_row()


@rerun_if_address_is_in_use()
def test_linear():
    spawn(run_dist, nprocs=2)


if __name__ == '__main__':
    test_linear()
```
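Running `python tests/test_shardformer/test_layer/test_linear_1d.py` (or the file under pytest) executes `test_linear`, which spawns two processes that each call `colossalai.launch` with the NCCL backend, so two CUDA devices are required. Note that both checks hard-code a world size of 2: the `torch.chunk(..., 2, ...)` calls and the expected shard shapes `[64, 32]` and `[128, 16]` only hold for two partitions of the `(128, 32)` weight.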