OpenDAS / ColossalAI / Commits / e52f9d91

Unverified commit e52f9d91, authored Nov 14, 2022 by アマデウス and committed via GitHub on Nov 14, 2022.

[tensorparallel] fixed tp layers (#1938)

Parent: cf68cc92

Showing 3 changed files with 107 additions and 91 deletions (+107 -91).
colossalai/nn/layer/parallel_1d/layers.py      +1  -2
colossalai/nn/layer/parallel_3d/_operation.py  +86 -79
colossalai/nn/layer/parallel_3d/layers.py      +20 -10
colossalai/nn/layer/parallel_1d/layers.py (view file @ e52f9d91)

@@ -77,12 +77,11 @@ class Linear1D(ColossalaiModule):
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
         parallel_input = get_parallel_input()
-        if not parallel_input:
+        if not parallel_input and not gather_output:
             layer = Linear1D_Col(in_features,
                                  out_features,
                                  bias=bias,
                                  dtype=dtype,
-                                 gather_output=gather_output,
                                  skip_bias_add=skip_bias_add,
                                  weight_initializer=weight_initializer,
                                  bias_initializer=bias_initializer)
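For context on what the changed condition guards: in Megatron-style 1D tensor parallelism, a column-parallel linear shards the weight along the output dimension, and gather_output controls whether the per-rank output slices are concatenated back into the full output. The snippet below is a minimal single-process sketch of that semantics, with shards simulated by tensor chunks; the function name and shapes are illustrative and this is not the ColossalAI Linear1D_Col implementation.

import torch

# Minimal single-process illustration of column-parallel linear semantics.
# "Ranks" are simulated by slicing the weight; no torch.distributed involved.
def column_parallel_linear(x, weight, bias, world_size, gather_output):
    # weight: (out_features, in_features), split along the output dimension
    w_shards = weight.chunk(world_size, dim=0)
    b_shards = bias.chunk(world_size, dim=0)
    partial_outputs = [x @ w.t() + b for w, b in zip(w_shards, b_shards)]
    if gather_output:
        # all-gather: every rank ends up with the full output
        return torch.cat(partial_outputs, dim=-1)
    # otherwise each rank keeps only its own slice of the output
    return partial_outputs

x = torch.randn(4, 8)
w = torch.randn(16, 8)
b = torch.randn(16)
full = column_parallel_linear(x, w, b, world_size=4, gather_output=True)
print(torch.allclose(full, x @ w.t() + b))  # True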
colossalai/nn/layer/parallel_3d/_operation.py (view file @ e52f9d91)

@@ -4,13 +4,15 @@
 from typing import Optional, Tuple
 
 import torch
+from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter)
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 
-from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter
+from ._utils import get_parallel_mode_from_env, push_async_grad
 from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from ._utils import get_parallel_mode_from_env, push_async_grad
 
 
 class _Linear3D(torch.autograd.Function):
@@ -44,18 +46,17 @@ class _Linear3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
+        with torch.no_grad():
             output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
 
             input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
             input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
 
             weight_grad = torch.matmul(
                 input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
             weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
             weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
 
             input_op.wait()
 
         return input_grad, weight_grad, None, None, None, None
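For reference, the collective-free core of this backward is the usual linear-layer gradient algebra: with the weight stored as (in_features, out_features), dX = dY Wᵀ and dW = Xᵀ dY after flattening the leading dimensions, which is exactly what the two matmul/reshape lines above compute; the all_gather and reduce_scatter calls only move and sum these partial results across the 3-D mesh. A single-device sketch with illustrative shapes, checked against autograd:

import torch

# Single-device restatement of the gradient algebra used in _Linear3D.backward.
# In the parallel version, output_grad is first all-gathered and the two matmul
# results are reduce-scattered; the local algebra is the same.
x = torch.randn(2, 5, 8)          # (..., in_features)
w = torch.randn(8, 6)             # (in_features, out_features)
output_grad = torch.randn(2, 5, 6)

input_grad = torch.matmul(output_grad, w.transpose(0, 1))         # dX = dY @ W^T
weight_grad = torch.matmul(
    x.reshape(-1, x.shape[-1]).transpose(0, 1),                   # X^T, flattened over batch dims
    output_grad.reshape(-1, output_grad.shape[-1]))                # dW = X^T @ dY

# Check against autograd on the same forward.
x_ref = x.clone().requires_grad_(True)
w_ref = w.clone().requires_grad_(True)
torch.matmul(x_ref, w_ref).backward(output_grad)
print(torch.allclose(input_grad, x_ref.grad),
      torch.allclose(weight_grad, w_ref.grad, atol=1e-5))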
@@ -129,25 +130,24 @@ class _Classifier3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
+        with torch.no_grad():
             weight_grad = torch.matmul(
                 output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
             weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
             if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
                 weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
                 weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
             else:
                 weight_grad = None
 
             if ctx.use_bias:
                 bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
                 bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
                 bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
                 bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
             else:
                 bias_grad = None
 
             input_grad = torch.matmul(output_grad, weight)
 
         return input_grad, weight_grad, bias_grad, None, None, None, None, None
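One recurring idiom in these backward passes: the bias gradient sums output_grad over every dimension except the last, and dim=tuple(range(len(output_grad.shape))[:-1]) is simply "all leading dimensions". A quick illustration with made-up shapes:

import torch

# The bias-gradient reduction used above: sum over all leading (batch/sequence)
# dimensions, keeping only the feature dimension.
output_grad = torch.randn(2, 3, 5)
leading_dims = tuple(range(len(output_grad.shape))[:-1])   # (0, 1)
bias_grad = torch.sum(output_grad, dim=leading_dims)
print(bias_grad.shape)                                      # torch.Size([5])
print(torch.allclose(bias_grad, output_grad.reshape(-1, 5).sum(dim=0)))  # True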
@@ -224,25 +224,24 @@ class _VocabParallelClassifier3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
+        with torch.no_grad():
             output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
 
             input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
             input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
 
             weight_grad = torch.matmul(
                 input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
             weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True)
             weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
 
             if ctx.use_bias:
                 bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
                 bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
                 bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
             else:
                 bias_grad = None
 
             input_op.wait()
 
         return input_grad, weight_grad, bias_grad, None, None, None, None, None
@@ -281,6 +280,30 @@ def vocab_parallel_classifier_3d(
     )
 
 
+@torch.jit.script
+def norm_forward(x, mean, sqr_mean, weight, bias, eps):
+    mu = x - mean
+    var = sqr_mean - mean**2
+    sigma = torch.sqrt(var + eps)
+    z = mu / sigma
+    output = weight * z + bias
+
+    return output, mu, sigma
+
+
+@torch.jit.script
+def norm_backward(grad, mu, sigma, weight):
+    # dbias, dweight = grad, grad * mu / sigma
+    dz = grad * weight
+    dmu = dz / sigma
+    dvar = dz * mu * (-0.5) * sigma**(-3)
+    dmean = -dmu
+    dvar = torch.sum(dvar, -1, keepdim=True)
+    dmean = torch.sum(dmean, -1, keepdim=True)
+
+    return dmu, dmean, dvar
+
+
 class _Layernorm3D(torch.autograd.Function):
 
     @staticmethod
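norm_forward reconstructs layer norm from the row mean and the row mean of squares (var = E[x²] - E[x]²), and norm_backward returns the per-element term dmu together with the row-reduced dmean and dvar that the caller combines into the input gradient. Below is a single-device restatement of the same formulas checked against autograd through F.layer_norm; it is only a sketch, since the scripted functions above run under torch.jit and are fed statistics that were already all-reduced over the output-parallel group. The local_* names are hypothetical.

import torch
import torch.nn.functional as F

# Single-device restatement of the norm_forward / norm_backward algebra above,
# checked against autograd through F.layer_norm.
def local_norm_forward(x, mean, sqr_mean, weight, bias, eps):
    mu = x - mean
    var = sqr_mean - mean**2            # E[x^2] - E[x]^2
    sigma = torch.sqrt(var + eps)
    z = mu / sigma
    return weight * z + bias, mu, sigma

def local_norm_backward(grad, mu, sigma, weight):
    dz = grad * weight
    dmu = dz / sigma
    dvar = torch.sum(dz * mu * (-0.5) * sigma**(-3), -1, keepdim=True)
    dmean = torch.sum(-dmu, -1, keepdim=True)
    return dmu, dmean, dvar

torch.manual_seed(0)
x = torch.randn(4, 16, dtype=torch.double, requires_grad=True)
w = torch.randn(16, dtype=torch.double)
b = torch.randn(16, dtype=torch.double)
eps = 1e-5

mean = x.mean(-1, keepdim=True)
sqr_mean = (x**2).mean(-1, keepdim=True)
y, mu, sigma = local_norm_forward(x, mean, sqr_mean, w, b, eps)

g = torch.randn_like(y)
dmu, dmean, dvar = local_norm_backward(g, mu, sigma, w)
input_grad = dmu + (dmean + 2 * dvar * mu) / x.shape[-1]   # same combination as _Layernorm3D.backward

y_ref = F.layer_norm(x, (16,), w, b, eps)
input_grad_ref, = torch.autograd.grad(y_ref, x, g)
print(torch.allclose(y, y_ref), torch.allclose(input_grad, input_grad_ref))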
@@ -294,27 +317,21 @@ class _Layernorm3D(torch.autograd.Function):
         bias_id: int,
         normalized_shape: int,
         eps: float,
-        input_parallel_mode: ParallelMode,
-        weight_parallel_mode: ParallelMode,
         output_parallel_mode: ParallelMode,
         input_x_weight_parallel_mode: ParallelMode,
     ) -> Tensor:
         ctx.weight_id = weight_id
         ctx.bias_id = bias_id
 
-        mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
-        mu = input_ - mean
-        var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
-        sigma = torch.sqrt(var + eps)
-
-        ctx.save_for_backward(mu, sigma, weight)
-
-        z = mu / sigma
-        output = weight * z + bias
+        sum_ = torch.sum(input_, dim=-1, keepdim=True)
+        sqr_sum = torch.sum(input_**2, dim=-1, keepdim=True)
+        mean, sqr_mean = all_reduce(torch.stack((sum_, sqr_sum)), output_parallel_mode) / normalized_shape
+
+        output, mu, sigma = norm_forward(input_, mean, sqr_mean, weight, bias, eps)
+        ctx.save_for_backward(mu, sigma, weight)
 
         ctx.normalized_shape = normalized_shape
-        ctx.input_parallel_mode = input_parallel_mode
-        ctx.weight_parallel_mode = weight_parallel_mode
         ctx.output_parallel_mode = output_parallel_mode
         ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode
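The rewritten forward computes the local row sum and row sum of squares, stacks them, and performs a single all_reduce for both statistics instead of one reduction for the mean and another for the variance; the variance then follows as E[x²] - mean². The sketch below simulates that stacked reduction with plain tensor chunks standing in for the output-parallel group (illustrative only):

import torch

# The stacked-statistics trick from the rewritten forward: one reduction yields
# both the mean and the mean of squares. Here the "output-parallel group" is
# simulated by chunking the hidden dimension; a real run would use all_reduce.
torch.manual_seed(0)
normalized_shape = 12
x = torch.randn(4, normalized_shape)
shards = x.chunk(3, dim=-1)                        # hypothetical per-device slices

partials = [torch.stack((s.sum(-1, keepdim=True), (s**2).sum(-1, keepdim=True))) for s in shards]
mean, sqr_mean = sum(partials) / normalized_shape  # the "all_reduce" of the stacked sums

var = sqr_mean - mean**2
print(torch.allclose(mean, x.mean(-1, keepdim=True)))
print(torch.allclose(var, x.var(-1, unbiased=False, keepdim=True), atol=1e-6))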
@@ -324,23 +341,18 @@ class _Layernorm3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         mu, sigma, weight = ctx.saved_tensors
+        with torch.no_grad():
             bias_grad, weight_grad = output_grad, output_grad * mu / sigma
             bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1]))
             bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True)
             bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
             weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1]))
             weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True)
             weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
 
-            dz = output_grad * weight
-            dvar = dz * mu * (-0.5) * sigma**(-3)
-            dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
-            dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
-            dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
-            input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape
+            dmu, dmean, dvar = norm_backward(output_grad, mu, sigma, weight)
+            dvar, dmean = all_reduce(torch.stack((dvar, dmean)), ctx.output_parallel_mode)
+            input_grad = dmu + (dmean + 2 * dvar * mu) / ctx.normalized_shape
 
         return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None
@@ -351,8 +363,6 @@ def layernorm_3d(
     bias: Tensor,
     normalized_shape: int,
     eps: float,
-    input_parallel_mode: ParallelMode,
-    weight_parallel_mode: ParallelMode,
     output_parallel_mode: ParallelMode,
     input_x_weight_parallel_mode: ParallelMode,
 ) -> Tensor:

@@ -368,9 +378,8 @@ def layernorm_3d(
         If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
         eps (float): a value added to the denominator for numerical stability
-        input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
-        weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
         output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
+        input_x_weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input x weight parallel mode.
 
     Note:
         The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found

@@ -384,8 +393,6 @@ def layernorm_3d(
         id(bias),
         normalized_shape,
         eps,
-        input_parallel_mode,
-        weight_parallel_mode,
         output_parallel_mode,
         input_x_weight_parallel_mode,
     )
colossalai/nn/layer/parallel_3d/layers.py (view file @ e52f9d91)

@@ -5,6 +5,9 @@ from typing import Callable
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter
+
 from colossalai.communication import all_reduce, broadcast
 from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
 from colossalai.context import ParallelMode, seed

@@ -13,16 +16,25 @@ from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.nn import init as init
 from colossalai.nn.layer.base_layer import ParallelLayer
 from colossalai.registry import LAYERS
-from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
-                                            partition_tensor_parallel_state_dict)
+from colossalai.utils.checkpointing import (
+    broadcast_state_dict,
+    gather_tensor_parallel_state_dict,
+    partition_tensor_parallel_state_dict,
+)
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter
 
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (all_gather_tensor_3d, classifier_3d, vocab_parallel_classifier_3d, layernorm_3d, linear_3d,
-                         reduce_scatter_tensor_3d, split_tensor_3d, split_batch_3d)
-from ._utils import get_depth_from_env, get_parallel_mode_from_env, swap_in_out_group, register_async_grad_hook
+from ._operation import (
+    all_gather_tensor_3d,
+    classifier_3d,
+    layernorm_3d,
+    linear_3d,
+    reduce_scatter_tensor_3d,
+    split_batch_3d,
+    split_tensor_3d,
+    vocab_parallel_classifier_3d,
+)
+from ._utils import get_depth_from_env, get_parallel_mode_from_env, register_async_grad_hook, swap_in_out_group
 
 
 @LAYERS.register_module
@@ -144,8 +156,6 @@ class LayerNorm3D(ParallelLayer):
             self.bias,
             self.normalized_shape,
             self.variance_epsilon,
-            self.input_parallel_mode,
-            self.weight_parallel_mode,
             self.output_parallel_mode,
             self.input_x_weight_parallel_mode,
         )

@@ -900,7 +910,7 @@ class PatchEmbedding3D(ParallelLayer):
                                        weight_parallel_mode=self.weight_parallel_mode)
         output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
         if self.flatten:
             output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC
 
         cls_token = self.cls_token.expand(output.shape[0], -1, -1)
         output = torch.cat((cls_token, output), dim=1)
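The forward path shown here is the standard ViT patch embedding: a strided conv2d turns (B, C, H, W) into (B, embed, H/p, W/p), which is flattened to (B, N, embed) and prepended with a class token. A minimal non-parallel sketch of that shape flow with made-up sizes (the real layer additionally broadcasts and partitions its parameters across the 3-D mesh):

import torch
import torch.nn.functional as F

# Non-parallel sketch of the PatchEmbedding3D forward path: strided conv2d
# patchify, flatten BCHW -> BNC, then prepend a class token.
B, C, H, W, patch, embed = 2, 3, 32, 32, 8, 24
x = torch.randn(B, C, H, W)
weight = torch.randn(embed, C, patch, patch)
bias = torch.randn(embed)
cls_token = torch.zeros(1, 1, embed)

output = F.conv2d(x, weight, bias, stride=patch)    # (B, embed, H/patch, W/patch)
output = output.flatten(2).transpose(1, 2)           # BCHW -> BNC, N = (H/patch)*(W/patch)
cls = cls_token.expand(output.shape[0], -1, -1)      # one class token per sample
output = torch.cat((cls, output), dim=1)
print(output.shape)                                   # torch.Size([2, 17, 24])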