OpenDAS / ColossalAI, commit 70c58cfd
authored Jun 23, 2023 by Frank Lee

[shardformer] supported fused qkv checkpoint (#4073)

parent 0803a614
Showing 10 changed files with 420 additions and 88 deletions
colossalai/shardformer/layer/_operation.py                +81  -5
colossalai/shardformer/layer/embedding.py                 +2   -2
colossalai/shardformer/layer/linear.py                    +5   -11
colossalai/shardformer/layer/linear_conv.py               +112 -50
colossalai/shardformer/layer/parallel_module.py           +7   -1
colossalai/tensor/d_tensor/__init__.py                    +6   -2
colossalai/tensor/d_tensor/api.py                         +127 -0
tests/test_shardformer/test_layer/test_linear_1d.py       +59  -5
tests/test_shardformer/test_layer/test_linearconv_1d.py   +14  -6
tests/test_shardformer/test_model/test_shard_gpt2.py      +7   -6
colossalai/shardformer/layer/_operation.py  (view file @ 70c58cfd)

import torch
import torch.distributed as dist
import torch.nn.functional as F

try:
    import fused_mix_prec_layer_norm_cuda
...

@@ -46,7 +47,7 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
        return grad_input, grad_weight, grad_bias, None, None


-class LinearWithAsyncCommunication(torch.autograd.Function):
+class MatmulWithAsyncCommunication(torch.autograd.Function):
    """
    Linear layer execution with asynchronous communication in backprop.
    """
...

@@ -58,11 +59,59 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
        ctx.process_group = process_group
        ctx.async_grad_allreduce = async_grad_allreduce

-       output = torch.matmul(input_, weight.t())
+       output = torch.matmul(input_, weight)

        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias

        total_input = input
        grad_input = grad_output.matmul(weight.T)
        grad_output = grad_output.contiguous()
        # Convert the tensor shapes to 2D for execution compatibility
        if len(grad_output.shape) > 2:
            grad_output = grad_output.view(-1, grad_output.shape[-1])
            total_input = total_input.view(-1, total_input.shape[-1])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
            handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # all-reduce scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1

        grad_weight = total_input.t().matmul(grad_output)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.async_grad_allreduce:
            handle.wait()

        return grad_input, grad_weight, grad_bias, None, None, None


class LinearWithAsyncCommunication(torch.autograd.Function):
    """
    Linear layer execution with asynchronous communication in backprop.
    """

    @staticmethod
    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None
        ctx.process_group = process_group
        ctx.async_grad_allreduce = async_grad_allreduce
        if bias is not None:
            output = F.linear(input_, weight, bias)
        else:
            output = F.linear(input_, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
...

@@ -114,7 +163,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
        return _gather(grad_output, ctx.dim, ctx.process_group), None, None


-class _ReduceInput(torch.autograd.Function):
+class _ReduceForward(torch.autograd.Function):
    """
    All-reduce the input from the model parallel region.
...

@@ -132,6 +181,25 @@ class _ReduceInput(torch.autograd.Function):
        return grad_output, None


class _ReduceBackward(torch.autograd.Function):
    """
    All-reduce the input from the model parallel region.

    Args:
        input_: input matrix.
        parallel_mode: parallel mode.
    """

    @staticmethod
    def forward(ctx, input_, process_group):
        ctx.process_group = process_group
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output, ctx.process_group), None


def _reduce(input_, process_group):
    # skip if only one rank involved
    if dist.get_world_size(process_group) == 1:
...

@@ -198,6 +266,10 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
        return _split(grad_output, ctx.dim, ctx.process_group), None, None


def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
    return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)


def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
    return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
...

@@ -210,5 +282,9 @@ def split_forward_gather_backward(input_, dim, process_group):
    return _SplitForwardGatherBackward.apply(input_, dim, process_group)


-def reduce_input(input_, process_group):
-    return _ReduceInput.apply(input_, process_group)
+def reduce_forward(input_, process_group):
+    return _ReduceForward.apply(input_, process_group)
+
+
+def reduce_backward(input_, process_group):
+    return _ReduceBackward.apply(input_, process_group)
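For readers skimming the diff: the commit splits the old `reduce_input` into a `reduce_forward` / `reduce_backward` pair so each direction of the all-reduce can be placed explicitly. A minimal usage sketch of the pattern (not part of the commit; assumes an already-initialized tensor-parallel process group of size greater than 1, and a hypothetical helper name):

```python
import torch

from colossalai.shardformer.layer._operation import reduce_backward, reduce_forward


def col_then_row_parallel_matmul(x, w_col_shard, w_row_shard, process_group):
    # reduce_backward is an identity in forward; its backward all-reduces
    # dL/dx across the tensor-parallel ranks (needed because every rank
    # consumes the full input but only produces a slice of the output).
    x = reduce_backward(x, process_group)
    partial = torch.matmul(x, w_col_shard)        # column-parallel: one output slice per rank
    partial = torch.matmul(partial, w_row_shard)  # row-parallel: one partial sum per rank
    # reduce_forward all-reduces the partial sums in forward and lets the
    # gradient flow through untouched in backward.
    return reduce_forward(partial, process_group)
```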
colossalai/shardformer/layer/embedding.py  (view file @ 70c58cfd)

...

@@ -15,7 +15,7 @@ from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
-from ._operation import gather_forward_split_backward, reduce_input
+from ._operation import gather_forward_split_backward, reduce_forward
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset
...

@@ -276,5 +276,5 @@ class VocabParallelEmbedding1D(ParallelModule):
        # Mask the output embedding.
        output_parallel[input_mask, :] = 0.
        # Reduce across all the model parallel GPUs.
-       output = reduce_input(output_parallel, self.process_group)
+       output = reduce_forward(output_parallel, self.process_group)
        return output
colossalai/shardformer/layer/linear.py  (view file @ 70c58cfd)

...

@@ -15,12 +15,11 @@ from torch.nn.parameter import Parameter
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
from colossalai.utils.cuda import get_current_device
from ._operation import (
    gather_forward_split_backward,
    linear_with_async_comm,
-   reduce_input,
+   reduce_forward,
    split_forward_gather_backward,
)
from .parallel_module import ParallelModule
...

@@ -148,9 +147,10 @@ class Linear1D_Col(ParallelModule):
        assert input_.shape[-1] == self.weight.shape[-1], \
            'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
                input_.shape, self.weight.shape, self.weight.shape[-1])
        # Set up backprop all-reduce.
        # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
        input_parallel = input_
        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
...

@@ -209,17 +209,14 @@ class Linear1D_Row(ParallelModule):
        self.parallel_input = parallel_input
        self.skip_bias_add = skip_bias_add
        self.process_group = process_group
        self.num_partitions = dist.get_world_size(self.process_group)

        if skip_bias_add and not bias:
            raise ValueError('cannot skip bias addition if bias is None')

        # Parameters.
        # Initialize weight.
        if device is None:
            device = get_current_device()
        factory_kwargs = {'device': device, 'dtype': dtype}
        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
        sharded_weight = shard_colwise(weight, self.process_group)
        self.weight = sharded_tensor_to_param(sharded_weight)
...

@@ -327,8 +324,7 @@ class Linear1D_Row(ParallelModule):
            output = torch.cat(output_parallel_list, dim=-1)
        else:
            output_parallel = F.linear(input_, self.weight)
            # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
-           output = reduce_input(output_parallel, self.process_group)
+           output = reduce_forward(output_parallel, self.process_group)

        if not self.skip_bias_add:
            if self.bias is not None:
...

@@ -336,5 +332,3 @@ class Linear1D_Row(ParallelModule):
            return output
        else:
            return output, self.bias
-           return output, self.bias
-           return output, self.bias
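The non-fused layers keep the plain SPMD d-tensor path. A small sketch (not from the commit) of what the new Linear1D_Row initialization above does, assuming an initialized 2-rank group; `shard_colwise` splits the trailing dimension, which for a `[out_features, in_features]` weight is exactly the row-parallel partitioning of `in_features`:

```python
import torch

from colossalai.tensor.d_tensor import shard_colwise, sharded_tensor_to_param

out_features, in_features = 128, 32
weight = torch.empty(out_features, in_features)

# each rank keeps a [128, 16] slice, but the resulting parameter still
# remembers its global [128, 32] layout, so (per parallel_module.py below)
# checkpoints are saved and loaded in the unsharded shape
sharded_weight = shard_colwise(weight, None)   # None: use the default process group
weight_param = sharded_tensor_to_param(sharded_weight)
```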
colossalai/shardformer/layer/linear_conv.py  (view file @ 70c58cfd)

...

@@ -14,13 +14,18 @@ from torch.nn.parameter import Parameter
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
from colossalai.utils.cuda import get_current_device
+from colossalai.tensor.d_tensor.api import (
+    customized_distributed_tensor_to_param,
+    distribute_tensor_with_customization,
+    shard_rowwise,
+    sharded_tensor_to_param,
+)
from ._operation import (
    gather_forward_split_backward,
    linear_with_async_comm,
-   reduce_input,
+   matmul_with_async_comm,
+   reduce_backward,
+   reduce_forward,
    split_forward_gather_backward,
)
from .parallel_module import ParallelModule
...

@@ -29,11 +34,69 @@ from .utils import create_randomizer_with_offset

__all__ = ['LinearConv1D_Col', 'LinearConv1D_Row']


def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup):
    """
    The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2].
    """
    # get the number of slice for the fused qkv
    rank = dist.get_rank(group=process_group)
    world_size = dist.get_world_size(group=process_group)
    order = torch.arange(world_size * n_fused)

    # split the fused qkv
    # from
    # [Q, K, V]
    # to
    # [Q1, Q2, K1, K2, V1, V2]
    weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)

    # rearrange the slice into the final order
    # from
    # [Q1, Q2, K1, K2, V1, V2]
    # to
    # [Q1, K1, V1], [Q2, K2, V2]
    weight_chunks_of_current_rank = [weight_chunks[i] for i in order[rank::world_size]]
    weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1)
    return weight_of_current_rank


def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup):
    """
    The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2].
    """
    world_size = dist.get_world_size(group=process_group)

    # gather the tensors
    # from
    # [Q1, K1, V1], [Q2, K2, V2]
    # to
    # [Q1, K1, V1, Q2, K2, V2]
    origin_device = qkv.device
    qkv = qkv.cuda()
    gather_list = [torch.zeros_like(qkv) for _ in range(world_size)]
    dist.all_gather(gather_list, qkv, group=process_group)
    gather_weight = torch.cat(gather_list, dim=-1)
    gather_weight = gather_weight.to(origin_device)
    qkv = qkv.to(origin_device)

    # rearrange the tensor slices
    # from
    # [Q1, K1, V1, Q2, K2, V2]
    # to
    # [Q1, Q2, K1, K2, V1, V2]
    weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
    reordered_chunk_list = []
    for i in range(n_fused):
        reordered_chunk_list.extend(weight_chunks[i::n_fused])
    reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
    return reordered_gather_weight


class LinearConv1D_Col(ParallelModule):
    r"""Linear layer with column parallelism.

    The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
-   its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface.
+   its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface.

    Args:
        in_features (int): size of each input sample.
...

@@ -41,6 +104,7 @@ class LinearConv1D_Col(ParallelModule):
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
        device (`torch.device`): The device of parameters, defaults to None.
+       n_fused (int): The number items fused, defaults to 3 (QKV).
        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
        gather_output (bool, optional): If true, call all-gather on output and make Y available
                                        to all GPUs, otherwise, every GPU will have its output
...

@@ -63,8 +127,10 @@ class LinearConv1D_Col(ParallelModule):
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 process_group: ProcessGroup = None,
+                async_communication: bool = False,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
+                n_fused: int = 3,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
        super().__init__()
...

@@ -75,23 +141,34 @@ class LinearConv1D_Col(ParallelModule):
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add
        self.device = device
+       self.n_fused = n_fused
        self.process_group = process_group
        self.num_partitions = dist.get_world_size(self.process_group)
+       self.async_communication = async_communication

        if skip_bias_add and not bias:
            raise ValueError('cannot skip bias addition if bias is None')

        self.out_features_per_partition = divide(out_features, self.num_partitions)

        # Parameters.
        # Initialize weight.
        if device is None:
            device = get_current_device()
        factory_kwargs = {'device': device, 'dtype': dtype}
-       self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
+       weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)
+
+       def shard_fn(tensor):
+           return split_fused_qkv(tensor, self.n_fused, self.process_group)
+
+       def gather_fn(tensor):
+           return gather_fused_qkv(tensor, 3, self.process_group)
+
+       with torch.no_grad():
+           sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
+       self.weight = customized_distributed_tensor_to_param(sharded_weight)

        if bias:
-           self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
+           bias = torch.empty(self.out_features, **factory_kwargs)
+           with torch.no_grad():
+               sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn)
+           self.bias = customized_distributed_tensor_to_param(sharded_bias)
        else:
            self.bias = None
...

@@ -103,7 +180,7 @@ class LinearConv1D_Col(ParallelModule):
        self.reset_parameters(weight_initializer, bias_initializer)

    @staticmethod
-   def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
+   def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
                           *args, **kwargs) -> ParallelModule:
        r"""
        Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
...

@@ -135,29 +212,12 @@ class LinearConv1D_Col(ParallelModule):
        # TODO: copy the sharded weights
        with torch.no_grad():
-           # the weigh to the linear layer is a transpose
-           # thus shard on row is equal to shard on column
-           # first rearange the order of weight and bias
-           world_size = dist.get_world_size(group=process_group)
-           order = torch.arange(world_size * n_fused)
-           new_order = []
-           for i in range(world_size):
-               new_order.append(order[i::world_size])
-           new_order = torch.cat(new_order)
-
-           weight_chunks = torch.chunk(module.weight.data, world_size * n_fused, dim=1)
-           rearanged_weight_chunks = [weight_chunks[i] for i in new_order]
-           rearanged_weight = torch.cat(rearanged_weight_chunks, dim=1)
-           sharded_weight = shard_colwise(rearanged_weight, process_group)
-           linear_1d.weight.data.copy_(sharded_weight.T.contiguous())
+           sharded_weight = split_fused_qkv(module.weight.data, n_fused=n_fused, process_group=process_group)
+           linear_1d.weight.data.copy_(sharded_weight.data)

            if bias:
-               bias_chunks = torch.chunk(module.bias.data, world_size * n_fused, dim=0)
-               rearanged_bias_chunks = [bias_chunks[i] for i in new_order]
-               rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0)
-               sharded_bias = shard_colwise(rearanged_bias, process_group)
-               linear_1d.bias.copy_(sharded_bias.contiguous())
+               sharded_bias = split_fused_qkv(module.bias.data, n_fused=n_fused, process_group=process_group)
+               linear_1d.bias.data.copy_(sharded_bias.data)

        return linear_1d
...

@@ -169,15 +229,18 @@ class LinearConv1D_Col(ParallelModule):
        bias_initializer(self.bias, fan_in=fan_in)

    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
-       assert input_.shape[-1] == self.weight.shape[-1], \
+       assert input_.shape[-1] == self.weight.shape[0], \
            'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
                input_.shape, self.weight.shape, self.weight.shape[-1])
        # Set up backprop all-reduce.
        # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
-       input_parallel = input_
+       input_parallel = reduce_backward(input_, self.process_group)
+       # input_parallel = input_
        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
-       output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+       output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group,
+                                                self.async_communication)

        if self.gather_output:
            # All-gather across the partitions.
...

@@ -192,7 +255,8 @@ class LinearConv1D_Col(ParallelModule):

class LinearConv1D_Row(ParallelModule):
-   r""" Linear layer with row parallelism
+   r""" Linear layer with row parallelism.
+   This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface.

    Args:
        in_features (int): size of each input sample.
...

@@ -243,11 +307,10 @@ class LinearConv1D_Row(ParallelModule):
        # Parameters.
        # Initialize weight.
        if device is None:
            device = get_current_device()
        factory_kwargs = {'device': device, 'dtype': dtype}
-       self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
+       weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)
+       sharded_weight = shard_rowwise(weight, self.process_group)
+       self.weight = sharded_tensor_to_param(sharded_weight)

        if self.stream_chunk_num > 1:
            # TODO() work for inference only
...

@@ -295,7 +358,7 @@ class LinearConv1D_Row(ParallelModule):
            # the weigh to the linear layer is a transpose
            # thus shard on col is equal to shard on row
            sharded_weight = shard_rowwise(module.weight.data, process_group)
-           linear_1d.weight.data.copy_(sharded_weight.T.contiguous())
+           linear_1d.weight.data.copy_(sharded_weight.data)

            if bias:
                linear_1d.bias.copy_(module.bias.data)
...

@@ -325,12 +388,12 @@ class LinearConv1D_Row(ParallelModule):
    def forward(self, input_: Tensor) -> Tensor:
        # Set up backprop all-reduce.
        if self.parallel_input:
-           assert input_.shape[-1] == self.weight.shape[-1], \
+           assert input_.shape[-1] == self.weight.shape[0], \
                'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
                    input_.shape, self.weight.shape, self.weight.shape[-1])
            input_ = input_
        else:
-           assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \
+           assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0], \
                'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
                    input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions)
            input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
...

@@ -342,7 +405,7 @@ class LinearConv1D_Row(ParallelModule):
            output_parallel_list = [None for i in range(self.stream_chunk_num)]
            handle_list = []
            for i in range(self.stream_chunk_num):
-               output_parallel_list[i] = F.linear(input_, self.weight_list[i])
+               output_parallel_list[i] = torch.matmul(input_, self.weight_list[i])
                handle = torch.distributed.all_reduce(output_parallel_list[i],
                                                      group=self.process_group,
                                                      async_op=True)
...

@@ -352,9 +415,8 @@ class LinearConv1D_Row(ParallelModule):
                handle.wait()
            output = torch.cat(output_parallel_list, dim=-1)
        else:
-           output_parallel = F.linear(input_, self.weight)
-           # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
-           output = reduce_input(output_parallel, self.process_group)
+           output_parallel = torch.matmul(input_, self.weight)
+           output = reduce_forward(output_parallel, self.process_group)

        if not self.skip_bias_add:
            if self.bias is not None:
...
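To make the interleaving above concrete, here is a tiny single-process sketch (not from the commit) of the chunk bookkeeping that `split_fused_qkv` performs, with `world_size = 2` and `n_fused = 3`:

```python
import torch

world_size, n_fused = 2, 3
# toy fused weight of shape [4, 6]; its six columns stand for Q1 Q2 K1 K2 V1 V2
qkv = torch.arange(6).repeat(4, 1)
order = torch.arange(world_size * n_fused)               # tensor([0, 1, 2, 3, 4, 5])
chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)  # six single-column chunks

for rank in range(world_size):
    # each rank keeps every world_size-th chunk starting at its own index
    shard = torch.cat([chunks[i] for i in order[rank::world_size]], dim=-1)
    print(rank, shard[0].tolist())  # rank 0 -> [0, 2, 4] (Q1, K1, V1); rank 1 -> [1, 3, 5] (Q2, K2, V2)
```

`gather_fused_qkv` is the inverse: after the all-gather produces [Q1, K1, V1, Q2, K2, V2], the `weight_chunks[i::n_fused]` loop interleaves the chunks back into the fused [Q1, Q2, K1, K2, V1, V2] layout.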
colossalai/shardformer/layer/parallel_module.py  (view file @ 70c58cfd)

...

@@ -12,11 +12,14 @@ from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module
from colossalai.tensor.d_tensor import (
    distribute_tensor,
+   distribute_tensor_with_customization,
    get_device_mesh,
    get_sharding_spec,
+   is_customized_distributed_tensor,
    is_distributed_tensor,
    sharded_tensor_to_param,
    to_global,
+   to_global_for_customized_distributed_tensor,
)

__all__ = ['ParallelModule']
...

@@ -54,9 +57,10 @@ class ParallelModule(nn.Module, ABC):
        for name, param in self._parameters.items():
            if param is not None:
                param_ = param if keep_vars else param.detach()
                if is_distributed_tensor(param_):
                    destination[prefix + name] = to_global(param_)
+               elif is_customized_distributed_tensor(param_):
+                   destination[prefix + name] = to_global_for_customized_distributed_tensor(param_)
                else:
                    destination[prefix + name] = param_
...

@@ -124,6 +128,8 @@ class ParallelModule(nn.Module, ABC):
                sharding_spec = get_sharding_spec(param)
                sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
                input_param = sharded_tensor_to_param(sharded_tensor)
+           elif is_customized_distributed_tensor(param):
+               input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn)

            # This is used to avoid copying uninitialized parameters into
            # non-lazy modules, since they dont have the hook to do the checks
...
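These two hooks are what make the fused-qkv checkpoint work end to end: saving gathers every customized shard back to the native fused layout, and loading re-applies `shard_fn` on each rank. A sketch of the round trip (mirroring the new assertions in `tests/test_shardformer/test_layer/test_linearconv_1d.py`; assumes an initialized 2-rank NCCL group, and note the `Conv1D` import path varies across transformers versions):

```python
import torch
from transformers.pytorch_utils import Conv1D  # path differs in older transformers releases

from colossalai.shardformer.layer import LinearConv1D_Col

conv = Conv1D(192, 48).cuda()  # fused qkv weight in Conv1D orientation: [48, 192]
col = LinearConv1D_Col.from_native_module(conv, process_group=None, gather_output=True, n_fused=3)

# state_dict() goes through gather_fn, so the saved weight has the native
# fused shape and loads straight back into the unsharded module...
conv.load_state_dict(col.state_dict())
# ...and load_state_dict() goes through shard_fn, so a native checkpoint
# loads straight into the sharded module.
col.load_state_dict(conv.state_dict())
```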
colossalai/tensor/d_tensor/__init__.py  (view file @ 70c58cfd)

from .api import (
    compute_global_numel,
+   customized_distributed_tensor_to_param,
    distribute_tensor,
+   distribute_tensor_with_customization,
    get_device_mesh,
    get_global_shape,
    get_layout,
    get_sharding_spec,
+   is_customized_distributed_tensor,
    is_distributed_tensor,
    is_sharded,
    redistribute,
...

@@ -12,6 +15,7 @@ from .api import (
    shard_rowwise,
    sharded_tensor_to_param,
    to_global,
+   to_global_for_customized_distributed_tensor,
)
from .layout import Layout
from .sharding_spec import ShardingSpec
...

@@ -19,6 +23,6 @@ from .sharding_spec import ShardingSpec
__all__ = [
    'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise',
    'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh',
-   'redistribute', 'get_layout' 'Layout', 'ShardingSpec'
+   'redistribute', 'get_layout', 'is_customized_distributed_tensor', 'distribute_tensor_with_customization',
+   'to_global_for_customized_distributed_tensor', 'customized_distributed_tensor_to_param', 'Layout', 'ShardingSpec'
]
colossalai/tensor/d_tensor/api.py  (view file @ 70c58cfd)

...

@@ -305,3 +305,130 @@ def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    return dtensor.dist_layout.sharding_spec


# ======================================================
# Some sharding does not obey the SPMD style
# e.g. Fused QKV layer in GPT2
# we support customize sharding with the following APIs
# ======================================================
def is_customized_distributed_tensor(tensor: torch.Tensor):
    """
    Check whether the given tensor is a customized distributed tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        bool: Whether the given tensor is a customized distributed tensor.
    """
    return hasattr(tensor, 'shard_fn') and hasattr(tensor, 'gather_fn')


def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.

    Args:
        tensor (torch.Tensor): The tensor to be hijacked.

    Returns:
        torch.Tensor: The hijacked tensor.
    """
    dtensor._old_detach = dtensor.detach
    dtensor._old_clone = dtensor.clone

    def new_detach(self):
        t_ = self._old_detach()
        t_.shard_fn = self.shard_fn
        t_.gather_fn = self.gather_fn
        return t_

    def new_clone(self, *args, **kwargs):
        t_ = self._old_clone(*args, **kwargs)
        t_.shard_fn = self.shard_fn
        t_.gather_fn = self.gather_fn
        return t_

    # bind the new methods to the tensor
    dtensor.detach = new_detach.__get__(dtensor)
    dtensor.clone = new_clone.__get__(dtensor)
    return dtensor


def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn, gather_fn: callable):
    """
    Distribute the given tensor with the given shard_fn and gather_fn.

    Example:

    ```python
    # define shard and gather functions
    def shard_fn(tensor):
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        return tensor.chunk(world_size, dim=0)[rank]

    def gather_fn(tensor):
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(shard_list, tensor)
        return torch.cat(shard_list, dim=0)

    # create a distributed tensor
    tensor = torch.rand(4, 4)
    dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn)
    ```

    Args:
        tensor (torch.Tensor): The tensor to be distributed.
        shard_fn (callable): The function to shard the tensor.
        gather_fn (callable): The function to gather the tensor.

    Returns:
        torch.Tensor: The distributed tensor.
    """
    assert callable(shard_fn), 'The shard_fn must be callable.'
    assert callable(gather_fn), 'The gather_fn must be callable.'
    assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'

    sharded_tensor = shard_fn(tensor)

    # set the shard_fn and gather_fn as attributes of the distributed tensor
    sharded_tensor.shard_fn = shard_fn
    sharded_tensor.gather_fn = gather_fn

    # set the shard_fn and gather_fn as attributes of the distributed tensor
    _hijack_detach_and_clone_for_customized_distributed_tensor(sharded_tensor)

    return sharded_tensor


def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Gather the given tensor to the global tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        torch.Tensor: The global tensor.
    """
    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
    return dtensor.gather_fn(dtensor)


def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
    """
    Convert the given customized distributed tensor to a parameter.
    """
    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'

    param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)

    # make it distributed as well
    param.shard_fn = dtensor.shard_fn
    param.gather_fn = dtensor.gather_fn
    _hijack_detach_and_clone_for_customized_distributed_tensor(param)
    return param
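A quick local sanity check (not from the commit; no process group required) of why the detach/clone hijack matters: `ParallelModule._save_to_state_dict` calls `param.detach()` before testing `is_customized_distributed_tensor`, so the `shard_fn`/`gather_fn` attributes must survive detaching and cloning:

```python
import torch

from colossalai.tensor.d_tensor import (
    customized_distributed_tensor_to_param,
    distribute_tensor_with_customization,
    is_customized_distributed_tensor,
)

identity = lambda t: t  # trivial shard/gather, enough to exercise the bookkeeping

dtensor = distribute_tensor_with_customization(torch.rand(4, 4), identity, identity)
param = customized_distributed_tensor_to_param(dtensor)

# without the hijack, detach()/clone() would return plain tensors and the
# state-dict hooks could no longer recognize the parameter as customized
assert is_customized_distributed_tensor(param.detach())
assert is_customized_distributed_tensor(param.clone())
```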
tests/test_shardformer/test_layer/test_linear_1d.py  (view file @ 70c58cfd)

...

@@ -27,8 +27,13 @@ def check_linear_1d_col():

    # check computation correctness
    x = torch.rand(4, 32).cuda()
-   out = linear(x)
-   gather_out = linear_col(x)
+   x_for_unshard = x.expand_as(x.clone())
+   x_for_unshard.requires_grad_(True)
+   x_for_shard = x.expand_as(x.clone())
+   x_for_shard.requires_grad_(True)
+
+   out = linear(x_for_unshard)
+   gather_out = linear_col(x_for_shard)
    assert_close(out, gather_out)

    # check backward correctness
...

@@ -39,6 +44,11 @@ def check_linear_1d_col():
    target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
    assert_close(target_grad, linear_col.weight.grad)

+   # check the input gradients
+   assert x_for_shard.grad is not None
+   assert x_for_unshard.grad is not None
+   assert_close(x_for_unshard.grad, x_for_shard.grad)


def check_linear_1d_row():
    linear = nn.Linear(32, 128).cuda()
...

@@ -49,8 +59,14 @@ def check_linear_1d_row():

    # check computation correctness
    x = torch.rand(4, 32).cuda()
-   out = linear(x)
-   gather_out = linear_row(x)
+   x_for_unshard = x.expand_as(x.clone())
+   x_for_unshard.requires_grad_(True)
+   x_for_shard = x.expand_as(x.clone())
+   x_for_shard.requires_grad_(True)
+
+   # run forward
+   out = linear(x_for_unshard)
+   gather_out = linear_row(x_for_shard)
    assert_close(out, gather_out)

    # check backward correctness
...

@@ -61,11 +77,49 @@ def check_linear_1d_row():
    target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
    assert_close(target_grad, linear_row.weight.grad)

+   # check the input gradients
+   assert x_for_shard.grad is not None
+   assert x_for_unshard.grad is not None
+   assert_close(x_for_unshard.grad, x_for_shard.grad)


def check_linear_col_plus_row():
    linear_1 = nn.Linear(32, 128).cuda()
    linear_2 = nn.Linear(128, 32).cuda()
    linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False)
    linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True)

    # check computation correctness
    x = torch.rand(4, 32).cuda()
    x_for_unshard = x.expand_as(x.clone())
    x_for_unshard.requires_grad_(True)
    x_for_shard = x.expand_as(x.clone())
    x_for_shard.requires_grad_(True)

    # run forward
    unshard_out = linear_2(linear_1(x_for_unshard))
    shard_out = linear_row(linear_col(x_for_shard))
    assert_close(unshard_out, shard_out)

    # check backward correctness
    unshard_out.sum().backward()
    shard_out.sum().backward()

    rank = dist.get_rank()
    target_1_grad = torch.chunk(linear_1.weight.grad, 2, dim=0)[rank]
    assert_close(target_1_grad, linear_col.weight.grad)

    # check the input gradients
    assert x_for_shard.grad is not None
    assert x_for_unshard.grad is not None
    assert_close(x_for_unshard.grad, x_for_shard.grad)


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    check_linear_1d_col()
-   # check_linear_1d_row()
+   check_linear_1d_row()
+   check_linear_col_plus_row()


@rerun_if_address_is_in_use()
...
tests/test_shardformer/test_layer/test_linearconv_1d.py  (view file @ 70c58cfd)

...

@@ -5,6 +5,7 @@ from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer import LinearConv1D_Col, LinearConv1D_Row
+from colossalai.shardformer.layer.linear_conv import split_fused_qkv
from colossalai.testing import rerun_if_address_is_in_use, spawn
...

@@ -53,9 +54,15 @@ def check_linear_conv_1d_col():
    linear = Conv1D(192, 48).cuda()
    linear_conv_col = LinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, n_fused=3)

-   assert linear_conv_col.weight.shape == torch.Size([96, 48])
+   assert linear.weight.shape == torch.Size([48, 192])
+   assert linear.bias.shape == torch.Size([192])
+   assert linear_conv_col.weight.shape == torch.Size([48, 96])
+   assert linear_conv_col.bias.shape == torch.Size([96])
+
+   # ensure weights are reversibly loadable
+   linear_conv_col.load_state_dict(linear.state_dict())
+   linear.load_state_dict(linear_conv_col.state_dict())

    # check computation correctness
    x = torch.rand(4, 48).cuda()
    out = linear(x)
...

@@ -66,16 +73,16 @@ def check_linear_conv_1d_col():
    out.sum().backward()
    gather_out.sum().backward()

    rank = dist.get_rank()
-   target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
-   assert_close(target_grad.transpose(0, 1).contiguous(), linear_conv_col.weight.grad)
+   target_grad = split_fused_qkv(linear.weight.grad, 3, None)
+   assert_close(target_grad, linear_conv_col.weight.grad)


def check_linear_1d_row():
    linear = Conv1D(192, 48).cuda()
    linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)

-   assert linear_row.weight.shape == torch.Size([192, 24])
+   assert linear.weight.shape == torch.Size([48, 192])
+   assert linear_row.weight.shape == torch.Size([24, 192])
    assert linear_row.bias.shape == torch.Size([192])

    # check computation correctness
...

@@ -89,13 +96,14 @@ def check_linear_1d_row():
    gather_out.sum().backward()

    rank = dist.get_rank()
-   target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
+   target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
    assert_close(target_grad, linear_row.weight.grad)


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    check_linear_conv_1d_col()
+   check_linear_1d_row()


@rerun_if_address_is_in_use()
...
tests/test_shardformer/test_model/test_shard_gpt2.py  (view file @ 70c58cfd)

...

@@ -20,20 +20,21 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

    # check grad equality
    if org_model.__class__.__name__ == 'GPT2Model':
-       org_grad = org_model.h[0].attn.c_attn.weight.grad
-       shard_grad = sharded_model.h[0].attn.c_attn.weight.grad.transpose(0, 1).contiguous()
+       org_grad = org_model.h[0].mlp.c_fc.weight.grad
+       shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
    else:
        org_grad = org_model.transformer.h[0].mlp.c_fc.weight.grad
-       shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad.transpose(0, 1).contiguous()
+       shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad

    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
    all_shard_grad = torch.cat(shard_grad_list, dim=1)

    assert torch.allclose(org_loss, shard_loss,
-                         atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
-   assert torch.allclose(org_grad, all_shard_grad,
-                         atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
+                         atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
+   assert torch.allclose(org_grad, all_shard_grad,
+                         atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"


def check_gpt2(rank, world_size, port):
...