OpenDAS / ColossalAI — commit e52f9d91

[tensorparallel] fixed tp layers (#1938)

Authored Nov 14, 2022 by アマデウス; committed by GitHub on Nov 14, 2022 (unverified signature).
Parent: cf68cc92

Showing 3 changed files with 107 additions and 91 deletions:

- colossalai/nn/layer/parallel_1d/layers.py: +1 −2
- colossalai/nn/layer/parallel_3d/_operation.py: +86 −79
- colossalai/nn/layer/parallel_3d/layers.py: +20 −10
colossalai/nn/layer/parallel_1d/layers.py (+1 −2)

The guard on the `Linear1D_Col` branch of `Linear1D` now also checks `gather_output`:

```diff
@@ -77,12 +77,11 @@ class Linear1D(ColossalaiModule):
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
         parallel_input = get_parallel_input()
-        if not parallel_input:
+        if not parallel_input and not gather_output:
             layer = Linear1D_Col(in_features,
                                  out_features,
                                  bias=bias,
                                  dtype=dtype,
                                  gather_output=gather_output,
                                  skip_bias_add=skip_bias_add,
                                  weight_initializer=weight_initializer,
                                  bias_initializer=bias_initializer)
```
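The hunk only shows the column-parallel branch; by the surrounding file's convention the fallback is the row-parallel `Linear1D_Row` (an assumption, since that branch is outside this hunk). A self-contained sketch of the dispatch rule after the fix:

```python
# Sketch of Linear1D's dispatch after this fix. The Linear1D_Row fallback is
# an assumption based on the rest of the file, which this hunk truncates.
def pick_linear_1d(parallel_input: bool, gather_output: bool) -> str:
    if not parallel_input and not gather_output:
        return "Linear1D_Col"   # replicated input, output stays column-split
    return "Linear1D_Row"       # split input, or the output must be gathered
```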
colossalai/nn/layer/parallel_3d/_operation.py (+86 −79)

The import block is regrouped (third-party, then first-party, then relative) and now pulls in the 3D group constants:

```diff
@@ -4,13 +4,15 @@
 from typing import Optional, Tuple

 import torch
-from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter)
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
-from ._utils import get_parallel_mode_from_env, push_async_grad
+
+from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter
+from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+
+from ._utils import get_parallel_mode_from_env, push_async_grad


 class _Linear3D(torch.autograd.Function):
```
In `_Linear3D.backward`, the `all_gather` of `output_grad` before the input-gradient matmul is removed:

```diff
@@ -44,7 +46,6 @@ class _Linear3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
         with torch.no_grad():
-            output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
             input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
```
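Shapes aside, the surviving line is the standard matmul backward. A quick single-device check, written for this note rather than taken from the commit:

```python
import torch

# For y = x @ w, autograd's input gradient equals output_grad @ w.T, which is
# exactly the torch.matmul(output_grad, weight.transpose(0, 1)) kept above.
x = torch.randn(4, 6, dtype=torch.float64, requires_grad=True)
w = torch.randn(6, 3, dtype=torch.float64)

y = x @ w
output_grad = torch.randn_like(y)
y.backward(output_grad)

assert torch.allclose(x.grad, output_grad @ w.transpose(0, 1))
```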
In `_Classifier3D.backward`, the extra `reduce` of `weight_grad` is dropped:

```diff
@@ -129,7 +130,6 @@ class _Classifier3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
         with torch.no_grad():
             weight_grad = torch.matmul(
                 output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1),
                 input_.reshape(-1, input_.shape[-1]))
-            weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
```
The same `all_gather` removal is applied in `_VocabParallelClassifier3D.backward`:

```diff
@@ -224,7 +224,6 @@ class _VocabParallelClassifier3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
         with torch.no_grad():
-            output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
             input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
```
Two `torch.jit.script` helpers are added below `vocab_parallel_classifier_3d`, fusing the element-wise layer-norm math:

```diff
@@ -281,6 +280,30 @@ def vocab_parallel_classifier_3d(
     )


+@torch.jit.script
+def norm_forward(x, mean, sqr_mean, weight, bias, eps):
+    mu = x - mean
+    var = sqr_mean - mean**2
+    sigma = torch.sqrt(var + eps)
+    z = mu / sigma
+    output = weight * z + bias
+
+    return output, mu, sigma
+
+
+@torch.jit.script
+def norm_backward(grad, mu, sigma, weight):
+    # dbias, dweight = grad, grad * mu / sigma
+    dz = grad * weight
+    dmu = dz / sigma
+    dvar = dz * mu * (-0.5) * sigma**(-3)
+    dmean = -dmu
+    dvar = torch.sum(dvar, -1, keepdim=True)
+    dmean = torch.sum(dmean, -1, keepdim=True)
+
+    return dmu, dmean, dvar
+
+
 class _Layernorm3D(torch.autograd.Function):

     @staticmethod
```
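To convince yourself the pair implements layer norm correctly, here is a single-device check against autograd, written for this note. It copies the two helpers from the diff (minus `@torch.jit.script`) and substitutes local means for the all-reduced statistics:

```python
import torch

def norm_forward(x, mean, sqr_mean, weight, bias, eps):
    mu = x - mean
    var = sqr_mean - mean**2
    sigma = torch.sqrt(var + eps)
    z = mu / sigma
    return weight * z + bias, mu, sigma

def norm_backward(grad, mu, sigma, weight):
    dz = grad * weight
    dmu = dz / sigma
    dvar = torch.sum(dz * mu * (-0.5) * sigma**(-3), -1, keepdim=True)
    dmean = torch.sum(-dmu, -1, keepdim=True)
    return dmu, dmean, dvar

torch.manual_seed(0)
n = 8
x = torch.randn(4, n, dtype=torch.float64, requires_grad=True)
weight = torch.randn(n, dtype=torch.float64)
bias = torch.randn(n, dtype=torch.float64)

# Local stand-ins for the all-reduced statistics on a single device.
mean = x.mean(-1, keepdim=True)
sqr_mean = (x**2).mean(-1, keepdim=True)

out, mu, sigma = norm_forward(x, mean, sqr_mean, weight, bias, 1e-5)
grad = torch.randn_like(out)
out.backward(grad)

dmu, dmean, dvar = norm_backward(grad, mu.detach(), sigma.detach(), weight)
input_grad = dmu + (dmean + 2 * dvar * mu.detach()) / n   # as in the new backward
assert torch.allclose(input_grad, x.grad)
```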
`_Layernorm3D.forward` now takes local sums and sums of squares, reduces both with a single stacked `all_reduce`, and normalizes via the jitted `norm_forward`:

```diff
@@ -294,27 +317,21 @@ class _Layernorm3D(torch.autograd.Function):
         bias_id: int,
         normalized_shape: int,
         eps: float,
         input_parallel_mode: ParallelMode,
         weight_parallel_mode: ParallelMode,
         output_parallel_mode: ParallelMode,
         input_x_weight_parallel_mode: ParallelMode,
     ) -> Tensor:
         ctx.weight_id = weight_id
         ctx.bias_id = bias_id

-        mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
-        mu = input_ - mean
-        var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
-        sigma = torch.sqrt(var + eps)
-
-        ctx.save_for_backward(mu, sigma, weight)
-
-        z = mu / sigma
-        output = weight * z + bias
+        sum_ = torch.sum(input_, dim=-1, keepdim=True)
+        sqr_sum = torch.sum(input_**2, dim=-1, keepdim=True)
+        mean, sqr_mean = all_reduce(torch.stack((sum_, sqr_sum)), output_parallel_mode) / normalized_shape
+
+        output, mu, sigma = norm_forward(input_, mean, sqr_mean, weight, bias, eps)
+
+        ctx.save_for_backward(mu, sigma, weight)

         ctx.normalized_shape = normalized_shape
         ctx.input_parallel_mode = input_parallel_mode
         ctx.weight_parallel_mode = weight_parallel_mode
         ctx.output_parallel_mode = output_parallel_mode
         ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode
```
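The point of the stacking is to pay for one collective instead of two. A minimal sketch, assuming plain `torch.distributed` in place of `colossalai.communication.all_reduce` (which returns its result rather than reducing in place):

```python
import torch
import torch.distributed as dist

def reduced_moments(x: torch.Tensor, normalized_shape: int, group=None):
    """Mean and mean of squares over a last dim sharded across `group`."""
    sum_ = torch.sum(x, dim=-1, keepdim=True)
    sqr_sum = torch.sum(x**2, dim=-1, keepdim=True)
    stacked = torch.stack((sum_, sqr_sum))       # shape (2, ...): reduce both at once
    dist.all_reduce(stacked, group=group)        # one collective, defaults to SUM
    mean, sqr_mean = stacked / normalized_shape  # unpack along the stacking dim
    return mean, sqr_mean
```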
In `_Layernorm3D.backward`, the bias and weight gradients are formed as before:

```diff
@@ -324,7 +341,6 @@ class _Layernorm3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         mu, sigma, weight = ctx.saved_tensors
         with torch.no_grad():
             bias_grad, weight_grad = output_grad, output_grad * mu / sigma
             bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1]))
```
The input gradient now reuses `norm_backward` and, as in the forward pass, merges the two reductions into one stacked `all_reduce`:

```diff
@@ -334,13 +350,9 @@ class _Layernorm3D(torch.autograd.Function):
             weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True)
             weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

-            dz = output_grad * weight
-            dvar = dz * mu * (-0.5) * sigma**(-3)
-            dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
-            dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
-            dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
-            input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape
+            dmu, dmean, dvar = norm_backward(output_grad, mu, sigma, weight)
+            dvar, dmean = all_reduce(torch.stack((dvar, dmean)), ctx.output_parallel_mode)
+            input_grad = dmu + (dmean + 2 * dvar * mu) / ctx.normalized_shape

         return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None
```
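For reference, the replacement computes the full layer-norm input gradient. Writing `n` for `normalized_shape`, `g` for `output_grad`, and per-row `mu`, `sigma` as saved, the chain rule gives (a sketch; `dmu`, `dmean`, `dvar` are exactly the quantities `norm_backward` returns):

```math
\frac{\partial L}{\partial x_i}
  = \underbrace{\frac{g_i w_i}{\sigma}}_{d\mu_i}
  + \frac{1}{n}\,\underbrace{\sum_j \left(-\frac{g_j w_j}{\sigma}\right)}_{d\mathrm{mean}}
  + \frac{2\mu_i}{n}\,\underbrace{\sum_j g_j w_j\,\mu_j\left(-\tfrac{1}{2}\right)\sigma^{-3}}_{d\mathrm{var}}
```

This uses `∂mean/∂x_i = 1/n` and `∂var/∂x_i = 2·mu_i/n`. The two row sums are the only cross-shard quantities, which is why a single stacked `all_reduce` over `(dvar, dmean)` suffices.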
The `layernorm_3d` wrapper keeps the four parallel modes in its signature:

```diff
@@ -351,8 +363,6 @@ def layernorm_3d(
     bias: Tensor,
     normalized_shape: int,
     eps: float,
     input_parallel_mode: ParallelMode,
     weight_parallel_mode: ParallelMode,
     output_parallel_mode: ParallelMode,
     input_x_weight_parallel_mode: ParallelMode,
 ) -> Tensor:
```
```diff
@@ -368,9 +378,8 @@ def layernorm_3d(
         If a single integer is used, it is treated as a singleton list, and this module will
         normalize over the last dimension which is expected to be of that specific size.
         eps (float): a value added to the denominator for numerical stability
         input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
         weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
         output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
         input_x_weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input x weight parallel mode.

     Note:
         The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
```
```diff
@@ -384,8 +393,6 @@ def layernorm_3d(
         id(bias),
         normalized_shape,
         eps,
         input_parallel_mode,
         weight_parallel_mode,
         output_parallel_mode,
         input_x_weight_parallel_mode,
     )
```
colossalai/nn/layer/parallel_3d/layers.py (+20 −10)

The `torch` imports move up with the other third-party imports:

```diff
@@ -5,6 +5,9 @@ from typing import Callable

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter
+
 from colossalai.communication import all_reduce, broadcast
 from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
 from colossalai.context import ParallelMode, seed
```
The remaining imports are reformatted and sorted, with `Tensor` and `Parameter` dropped from their old position:

```diff
@@ -13,16 +16,25 @@ from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.nn import init as init
 from colossalai.nn.layer.base_layer import ParallelLayer
 from colossalai.registry import LAYERS
-from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
-                                            partition_tensor_parallel_state_dict)
+from colossalai.utils.checkpointing import (
+    broadcast_state_dict,
+    gather_tensor_parallel_state_dict,
+    partition_tensor_parallel_state_dict,
+)
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter

 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (all_gather_tensor_3d, classifier_3d, vocab_parallel_classifier_3d, layernorm_3d, linear_3d,
-                         reduce_scatter_tensor_3d, split_tensor_3d, split_batch_3d)
-from ._utils import get_depth_from_env, get_parallel_mode_from_env, swap_in_out_group, register_async_grad_hook
+from ._operation import (
+    all_gather_tensor_3d,
+    classifier_3d,
+    layernorm_3d,
+    linear_3d,
+    reduce_scatter_tensor_3d,
+    split_batch_3d,
+    split_tensor_3d,
+    vocab_parallel_classifier_3d,
+)
+from ._utils import get_depth_from_env, get_parallel_mode_from_env, register_async_grad_hook, swap_in_out_group


 @LAYERS.register_module
```
`LayerNorm3D.forward` forwards the parallel modes to `layernorm_3d`:

```diff
@@ -144,8 +156,6 @@ class LayerNorm3D(ParallelLayer):
             self.bias,
             self.normalized_shape,
             self.variance_epsilon,
             self.input_parallel_mode,
             self.weight_parallel_mode,
             self.output_parallel_mode,
             self.input_x_weight_parallel_mode,
         )
```