Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
8823cc48
Unverified
Commit
8823cc48
authored
Jan 29, 2024
by
Frank Lee
Committed by
GitHub
Jan 29, 2024
Browse files
Merge pull request #5310 from hpcaitech/feature/npu
Feature/npu
parents
bce9499e
73f4dc57
Changes
266
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
177 additions
and
101 deletions
+177
-101
colossalai/legacy/communication/p2p.py
colossalai/legacy/communication/p2p.py
+7
-3
colossalai/legacy/communication/ring.py
colossalai/legacy/communication/ring.py
+3
-3
colossalai/legacy/communication/utils.py
colossalai/legacy/communication/utils.py
+3
-3
colossalai/legacy/engine/schedule/_base_schedule.py
colossalai/legacy/engine/schedule/_base_schedule.py
+3
-3
colossalai/legacy/engine/schedule/_pipeline_schedule.py
colossalai/legacy/engine/schedule/_pipeline_schedule.py
+3
-3
colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
+2
-2
colossalai/legacy/initialize.py
colossalai/legacy/initialize.py
+3
-3
colossalai/legacy/nn/layer/colossalai_layer/embedding.py
colossalai/legacy/nn/layer/colossalai_layer/embedding.py
+2
-2
colossalai/legacy/nn/layer/colossalai_layer/normalization.py
colossalai/legacy/nn/layer/colossalai_layer/normalization.py
+2
-2
colossalai/legacy/nn/layer/parallel_1d/layers.py
colossalai/legacy/nn/layer/parallel_1d/layers.py
+14
-8
colossalai/legacy/nn/layer/parallel_2d/_operation.py
colossalai/legacy/nn/layer/parallel_2d/_operation.py
+4
-4
colossalai/legacy/nn/layer/parallel_2d/layers.py
colossalai/legacy/nn/layer/parallel_2d/layers.py
+27
-12
colossalai/legacy/nn/layer/parallel_2p5d/_operation.py
colossalai/legacy/nn/layer/parallel_2p5d/_operation.py
+8
-6
colossalai/legacy/nn/layer/parallel_2p5d/layers.py
colossalai/legacy/nn/layer/parallel_2p5d/layers.py
+27
-12
colossalai/legacy/nn/layer/parallel_3d/layers.py
colossalai/legacy/nn/layer/parallel_3d/layers.py
+38
-15
colossalai/legacy/nn/layer/parallel_sequence/_operation.py
colossalai/legacy/nn/layer/parallel_sequence/_operation.py
+7
-5
colossalai/legacy/nn/layer/parallel_sequence/layers.py
colossalai/legacy/nn/layer/parallel_sequence/layers.py
+1
-2
colossalai/legacy/nn/layer/vanilla/layers.py
colossalai/legacy/nn/layer/vanilla/layers.py
+19
-9
colossalai/legacy/nn/loss/loss_2d.py
colossalai/legacy/nn/loss/loss_2d.py
+2
-2
colossalai/legacy/nn/loss/loss_2p5d.py
colossalai/legacy/nn/loss/loss_2p5d.py
+2
-2
No files found.
colossalai/legacy/communication/p2p.py
View file @
8823cc48
...
...
@@ -8,9 +8,9 @@ from typing import List, Tuple, Union
import
torch
import
torch.distributed
as
dist
from
colossalai.accelerator
import
get_accelerator
from
colossalai.legacy.context.parallel_mode
import
ParallelMode
from
colossalai.legacy.core
import
global_context
as
gpc
from
colossalai.utils
import
get_current_device
from
.utils
import
gather_split_1d_tensor
,
split_tensor_into_1d_equal_chunks
...
...
@@ -43,12 +43,16 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
def
create_recv_buffer_with_shapes
(
recv_shapes
,
dtype
,
scatter_gather_tensors
):
if
isinstance
(
recv_shapes
,
torch
.
Size
):
recv_chunk_shape
,
recv_split
=
_get_tensor_shape
(
recv_shapes
,
scatter_gather_tensors
)
buffer_recv
=
torch
.
empty
(
recv_chunk_shape
,
requires_grad
=
True
,
device
=
get_current_device
(),
dtype
=
dtype
)
buffer_recv
=
torch
.
empty
(
recv_chunk_shape
,
requires_grad
=
True
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
return
buffer_recv
,
recv_split
buffer_recv
=
[]
for
recv_shape
in
recv_shapes
:
recv_chunk_shape
,
recv_split
=
_get_tensor_shape
(
recv_shape
,
scatter_gather_tensors
)
tensor_recv
=
torch
.
empty
(
recv_chunk_shape
,
requires_grad
=
True
,
device
=
get_current_device
(),
dtype
=
dtype
)
tensor_recv
=
torch
.
empty
(
recv_chunk_shape
,
requires_grad
=
True
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
buffer_recv
.
append
(
tensor_recv
)
return
buffer_recv
,
recv_split
...
...
colossalai/legacy/communication/ring.py
View file @
8823cc48
...
...
@@ -3,9 +3,9 @@
import
torch
from
colossalai.accelerator
import
get_accelerator
from
colossalai.legacy.context.parallel_mode
import
ParallelMode
from
colossalai.legacy.core
import
global_context
as
gpc
from
colossalai.utils
import
get_current_device
,
synchronize
def
ring_forward
(
tensor_send_next
:
torch
.
Tensor
,
parallel_mode
:
ParallelMode
)
->
torch
.
Tensor
:
...
...
@@ -29,7 +29,7 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
current_rank
=
gpc
.
get_global_rank
()
tensor_recv_prev
=
torch
.
empty
(
buffer_shape
,
requires_grad
=
True
,
device
=
get_current_device
(),
dtype
=
tensor_send_next
.
dtype
buffer_shape
,
requires_grad
=
True
,
device
=
get_
accelerator
().
get_
current_device
(),
dtype
=
tensor_send_next
.
dtype
)
# send to next rank
...
...
@@ -52,6 +52,6 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
req
.
wait
()
# To protect against race condition when using batch_isend_irecv().
synchronize
()
get_accelerator
().
synchronize
()
return
tensor_recv_prev
colossalai/legacy/communication/utils.py
View file @
8823cc48
...
...
@@ -3,9 +3,9 @@ from typing import List, Tuple, Union
import
torch
import
torch.distributed
as
dist
from
colossalai.accelerator
import
get_accelerator
from
colossalai.legacy.context.parallel_mode
import
ParallelMode
from
colossalai.legacy.core
import
global_context
as
gpc
from
colossalai.utils
import
get_current_device
TensorShape
=
Union
[
torch
.
Size
,
List
[
int
],
Tuple
[
int
]]
...
...
@@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
if
next_rank
is
None
:
next_rank
=
gpc
.
get_next_global_rank
(
ParallelMode
.
PIPELINE
)
tensor_kwargs
=
{
"dtype"
:
torch
.
long
,
"device"
:
get_current_device
()}
tensor_kwargs
=
{
"dtype"
:
torch
.
long
,
"device"
:
get_accelerator
().
get_current_device
()}
if
isinstance
(
obj
,
torch
.
Tensor
):
send_obj_nums
=
torch
.
tensor
(
1
,
**
tensor_kwargs
)
dist
.
send
(
send_obj_nums
,
next_rank
)
...
...
@@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
if
prev_rank
is
None
:
prev_rank
=
gpc
.
get_prev_global_rank
(
ParallelMode
.
PIPELINE
)
tensor_kwargs
=
{
"dtype"
:
torch
.
long
,
"device"
:
get_current_device
()}
tensor_kwargs
=
{
"dtype"
:
torch
.
long
,
"device"
:
get_accelerator
().
get_current_device
()}
recv_obj_nums
=
torch
.
empty
((),
**
tensor_kwargs
)
dist
.
recv
(
recv_obj_nums
,
prev_rank
)
if
recv_obj_nums
.
item
()
==
1
:
...
...
colossalai/legacy/engine/schedule/_base_schedule.py
View file @
8823cc48
...
...
@@ -6,8 +6,8 @@ from typing import Callable, Iterable
import
torch
from
colossalai.accelerator
import
get_accelerator
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils
import
get_current_device
class
BaseSchedule
(
ABC
):
...
...
@@ -29,12 +29,12 @@ class BaseSchedule(ABC):
def
_move_tensor
(
element
):
if
torch
.
is_tensor
(
element
):
if
not
element
.
is_cuda
:
return
element
.
to
(
get_current_device
()).
detach
()
return
element
.
to
(
get_
accelerator
().
get_
current_device
()).
detach
()
return
element
def
_move_to_device
(
self
,
data
):
if
isinstance
(
data
,
torch
.
Tensor
):
data
=
data
.
to
(
get_current_device
())
data
=
data
.
to
(
get_
accelerator
().
get_
current_device
())
elif
isinstance
(
data
,
(
list
,
tuple
)):
data_to_return
=
[]
for
element
in
data
:
...
...
colossalai/legacy/engine/schedule/_pipeline_schedule.py
View file @
8823cc48
...
...
@@ -7,12 +7,12 @@ from typing import Callable, List, Tuple, Union
import
torch.cuda
import
colossalai.legacy.communication
as
comm
from
colossalai.accelerator
import
get_accelerator
from
colossalai.legacy.amp.naive_amp
import
NaiveAMPModel
from
colossalai.legacy.context.parallel_mode
import
ParallelMode
from
colossalai.legacy.core
import
global_context
as
gpc
from
colossalai.legacy.utils
import
switch_virtual_pipeline_parallel_rank
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils.device
import
get_current_device
from
._base_schedule
import
BaseSchedule
...
...
@@ -352,7 +352,7 @@ class PipelineSchedule(BaseSchedule):
output_objs
=
[]
return_tensors
=
[]
if
return_loss
and
gpc
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
accum_loss
=
torch
.
zeros
(
1
,
device
=
get_current_device
())
accum_loss
=
torch
.
zeros
(
1
,
device
=
get_
accelerator
().
get_
current_device
())
else
:
accum_loss
=
None
# Used for tensor meta information communication
...
...
@@ -584,7 +584,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
if
not
forward_only
:
output_obj_grads
=
[[]
for
_
in
range
(
len
(
model
))]
if
return_loss
and
gpc
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
accum_loss
=
torch
.
zeros
(
1
,
device
=
get_current_device
())
accum_loss
=
torch
.
zeros
(
1
,
device
=
get_
accelerator
().
get_
current_device
())
else
:
accum_loss
=
None
...
...
colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
View file @
8823cc48
...
...
@@ -6,10 +6,10 @@ from typing import Iterable, Tuple
import
torch.cuda
import
colossalai.legacy.communication.p2p_v2
as
comm
from
colossalai.accelerator
import
get_accelerator
from
colossalai.legacy.context.parallel_mode
import
ParallelMode
from
colossalai.legacy.core
import
global_context
as
gpc
from
colossalai.legacy.engine
import
Engine
from
colossalai.utils.device
import
get_current_device
from
._pipeline_schedule
import
PipelineSchedule
...
...
@@ -99,7 +99,7 @@ class PipelineScheduleV2(PipelineSchedule):
output_objs
=
[]
return_tensors
=
[]
if
return_loss
and
gpc
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
accum_loss
=
torch
.
zeros
(
1
,
device
=
get_current_device
())
accum_loss
=
torch
.
zeros
(
1
,
device
=
get_
accelerator
().
get_
current_device
())
else
:
accum_loss
=
None
...
...
colossalai/legacy/initialize.py
View file @
8823cc48
...
...
@@ -15,6 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
from
torch.optim.optimizer
import
Optimizer
from
torch.utils.data
import
DataLoader
from
colossalai.accelerator
import
get_accelerator
from
colossalai.context
import
Config
,
ConfigException
from
colossalai.interface
import
OptimizerWrapper
from
colossalai.legacy.amp
import
AMP_TYPE
,
convert_to_amp
...
...
@@ -34,7 +35,6 @@ from colossalai.legacy.utils import is_using_ddp, is_using_pp, is_using_sequence
from
colossalai.legacy.zero
import
ShardedOptimizerV2
,
convert_to_zero_v2
from
colossalai.legacy.zero.gemini.ophooks
import
BaseOpHook
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils
import
get_current_device
def
get_default_parser
():
...
...
@@ -309,9 +309,9 @@ def initialize(
else
:
if
isinstance
(
model
,
nn
.
Module
):
# first sync model across dp ranks
model
.
to
(
get_current_device
())
model
.
to
(
get_
accelerator
().
get_
current_device
())
elif
isinstance
(
model
,
Callable
):
model
=
model
().
to
(
get_current_device
())
model
=
model
().
to
(
get_
accelerator
().
get_
current_device
())
# optimizer maybe a optimizer_cls
if
isinstance
(
optimizer
,
Callable
):
...
...
colossalai/legacy/nn/layer/colossalai_layer/embedding.py
View file @
8823cc48
...
...
@@ -3,8 +3,8 @@ from typing import Callable
from
torch
import
dtype
,
nn
from
colossalai.accelerator
import
get_accelerator
from
colossalai.nn
import
init
from
colossalai.utils
import
get_current_device
from
..parallel_1d
import
Embedding1D
,
PatchEmbedding1D
,
VocabParallelEmbedding1D
from
..parallel_2d
import
Embedding2D
,
PatchEmbedding2D
,
VocabParallelEmbedding2D
...
...
@@ -83,7 +83,7 @@ class Embedding(ColossalaiModule):
embed
=
(
nn
.
Embedding
(
num_embeddings
,
embedding_dim
,
padding_idx
=
padding_idx
,
*
args
,
**
kwargs
)
.
to
(
dtype
)
.
to
(
get_current_device
())
.
to
(
get_
accelerator
().
get_
current_device
())
)
weight_initializer
(
embed
.
weight
,
fan_in
=
num_embeddings
,
fan_out
=
embedding_dim
)
elif
num_embeddings
<=
vocab_parallel_limit
:
...
...
colossalai/legacy/nn/layer/colossalai_layer/normalization.py
View file @
8823cc48
from
torch
import
nn
from
colossalai.
utils
import
get_
current_device
from
colossalai.
accelerator
import
get_
accelerator
from
..parallel_1d
import
LayerNorm1D
from
..parallel_2d
import
LayerNorm2D
...
...
@@ -36,7 +36,7 @@ class LayerNorm(ColossalaiModule):
def
__init__
(
self
,
normalized_shape
:
int
,
eps
=
1e-05
,
bias
=
True
,
dtype
=
None
)
->
None
:
tensor_parallel
=
get_tensor_parallel_mode
()
if
tensor_parallel
is
None
:
norm
=
nn
.
LayerNorm
(
normalized_shape
,
eps
=
eps
).
to
(
dtype
).
to
(
get_current_device
())
norm
=
nn
.
LayerNorm
(
normalized_shape
,
eps
=
eps
).
to
(
dtype
).
to
(
get_
accelerator
().
get_
current_device
())
else
:
norm
=
_parallel_layernorm
[
tensor_parallel
](
normalized_shape
,
eps
=
eps
,
dtype
=
dtype
)
super
().
__init__
(
norm
)
colossalai/legacy/nn/layer/parallel_1d/layers.py
View file @
8823cc48
...
...
@@ -10,7 +10,7 @@ import torch.nn.functional as F
from
torch
import
Tensor
from
torch.nn.parameter
import
Parameter
from
colossalai.
kernel
import
LayerN
or
m
from
colossalai.
accelerator
import
get_accelerat
or
from
colossalai.legacy.communication
import
broadcast
from
colossalai.legacy.context
import
ParallelMode
,
seed
from
colossalai.legacy.context.parallel_context
import
global_context
as
gpc
...
...
@@ -22,7 +22,7 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict
,
)
from
colossalai.nn
import
init
as
init
from
colossalai.
utils.device
import
get_current_device
from
colossalai.
nn.layer.layernorm
import
MixedFusedLayerNorm
as
LayerNorm
from
..base_layer
import
ParallelLayer
from
..colossalai_layer._utils
import
ColossalaiModule
...
...
@@ -221,7 +221,7 @@ class Classifier1D(ParallelLayer):
# Parameters.
# Initialize weight.
factory_kwargs
=
{
"device"
:
get_current_device
(),
"dtype"
:
dtype
}
factory_kwargs
=
{
"device"
:
get_accelerator
().
get_current_device
(),
"dtype"
:
dtype
}
if
weight
is
not
None
:
self
.
weight
=
weight
self
.
has_weight
=
False
...
...
@@ -357,7 +357,7 @@ class VocabParallelClassifier1D(ParallelLayer):
# Parameters.
# Initialize weight.
factory_kwargs
=
{
"device"
:
get_current_device
(),
"dtype"
:
dtype
}
factory_kwargs
=
{
"device"
:
get_accelerator
().
get_current_device
(),
"dtype"
:
dtype
}
if
weight
is
not
None
:
self
.
weight
=
weight
self
.
has_weight
=
False
...
...
@@ -499,7 +499,7 @@ class Linear1D_Col(ParallelLayer):
# Parameters.
# Initialize weight.
factory_kwargs
=
{
"device"
:
get_current_device
(),
"dtype"
:
dtype
}
factory_kwargs
=
{
"device"
:
get_accelerator
().
get_current_device
(),
"dtype"
:
dtype
}
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
out_features_per_partition
,
self
.
in_features
,
**
factory_kwargs
))
if
bias
:
...
...
@@ -638,7 +638,7 @@ class Linear1D_Row(ParallelLayer):
# Parameters.
# Initialize weight.
factory_kwargs
=
{
"device"
:
get_current_device
(),
"dtype"
:
dtype
}
factory_kwargs
=
{
"device"
:
get_accelerator
().
get_current_device
(),
"dtype"
:
dtype
}
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
out_features
,
self
.
input_size_per_partition
,
**
factory_kwargs
))
if
self
.
stream_chunk_num
>
1
:
...
...
@@ -802,7 +802,9 @@ class Embedding1D(ParallelLayer):
self
.
embed_kwargs
=
kwargs
self
.
weight
=
Parameter
(
torch
.
empty
((
num_embeddings
,
embed_dim_per_partition
),
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
empty
(
(
num_embeddings
,
embed_dim_per_partition
),
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
self
.
reset_parameters
(
weight_initializer
)
...
...
@@ -912,7 +914,11 @@ class VocabParallelEmbedding1D(ParallelLayer):
self
.
vocab_end_index
=
self
.
vocab_start_index
+
self
.
num_embeddings_per_partition
self
.
weight
=
Parameter
(
torch
.
empty
((
self
.
num_embeddings_per_partition
,
self
.
embed_dim
),
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
empty
(
(
self
.
num_embeddings_per_partition
,
self
.
embed_dim
),
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
,
)
)
self
.
reset_parameters
(
weight_initializer
)
...
...
colossalai/legacy/nn/layer/parallel_2d/_operation.py
View file @
8823cc48
...
...
@@ -5,10 +5,10 @@ import torch.distributed as dist
from
torch
import
Tensor
from
torch.cuda.amp
import
custom_bwd
,
custom_fwd
from
colossalai.accelerator
import
get_accelerator
from
colossalai.legacy.communication.collective
import
all_gather
,
all_reduce
,
reduce_scatter
from
colossalai.legacy.context.parallel_mode
import
ParallelMode
from
colossalai.legacy.core
import
global_context
as
gpc
from
colossalai.utils
import
get_current_device
def
matmul_2d
(
...
...
@@ -250,7 +250,7 @@ class Matmul_AB_2D(torch.autograd.Function):
B_shape
=
B
.
shape
B
=
B
.
reshape
((
-
1
,
B_shape
[
-
1
]))
C_shape
=
(
A
.
shape
[
0
],
B
.
shape
[
-
1
])
C
=
torch
.
zeros
(
C_shape
,
dtype
=
A
.
dtype
,
device
=
get_current_device
())
C
=
torch
.
zeros
(
C_shape
,
dtype
=
A
.
dtype
,
device
=
get_
accelerator
().
get_
current_device
())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
...
...
@@ -399,7 +399,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
B_shape
=
B
.
shape
B
=
B
.
reshape
((
-
1
,
B_shape
[
-
1
]))
C_shape
=
(
A
.
shape
[
0
],
B
.
shape
[
0
])
C
=
torch
.
empty
(
C_shape
,
dtype
=
A
.
dtype
,
device
=
get_current_device
())
C
=
torch
.
empty
(
C_shape
,
dtype
=
A
.
dtype
,
device
=
get_
accelerator
().
get_
current_device
())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
...
...
@@ -556,7 +556,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
B_shape
=
B
.
shape
B
=
B
.
reshape
((
-
1
,
B_shape
[
-
1
]))
C_shape
=
(
A
.
shape
[
-
1
],
B
.
shape
[
-
1
])
C
=
torch
.
empty
(
C_shape
,
dtype
=
A
.
dtype
,
device
=
get_current_device
())
C
=
torch
.
empty
(
C_shape
,
dtype
=
A
.
dtype
,
device
=
get_
accelerator
().
get_
current_device
())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
...
...
colossalai/legacy/nn/layer/parallel_2d/layers.py
View file @
8823cc48
...
...
@@ -8,6 +8,7 @@ import torch.nn.functional as F
from
torch
import
Tensor
from
torch.nn
import
Parameter
from
colossalai.accelerator
import
get_accelerator
from
colossalai.legacy.communication
import
broadcast
from
colossalai.legacy.context
import
ParallelMode
,
seed
from
colossalai.legacy.core
import
global_context
as
gpc
...
...
@@ -18,7 +19,6 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict
,
)
from
colossalai.nn
import
init
as
init
from
colossalai.utils.device
import
get_current_device
from
..base_layer
import
ParallelLayer
from
..utils
import
divide
,
set_tensor_parallel_attribute_by_partition
,
to_2tuple
...
...
@@ -82,7 +82,7 @@ class Linear2D(ParallelLayer):
self
.
hidden_size_per_partition
=
divide
(
self
.
out_features
,
self
.
summa_dim
)
# create weight, shape: [k/q, h/q]
factory_kwargs
=
{
"device"
:
get_current_device
(),
"dtype"
:
dtype
}
factory_kwargs
=
{
"device"
:
get_accelerator
().
get_current_device
(),
"dtype"
:
dtype
}
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
input_size_per_partition
,
self
.
hidden_size_per_partition
,
**
factory_kwargs
)
)
...
...
@@ -259,7 +259,7 @@ class LayerNorm2D(ParallelLayer):
self
.
partitioned_partition
=
divide
(
normalized_shape
,
self
.
summa_dim
**
2
)
# create parameters
factory_kwargs
=
{
"device"
:
get_current_device
(),
"dtype"
:
dtype
}
factory_kwargs
=
{
"device"
:
get_accelerator
().
get_current_device
(),
"dtype"
:
dtype
}
self
.
weight
=
Parameter
(
torch
.
ones
(
self
.
partitioned_partition
,
**
factory_kwargs
))
if
bias
:
...
...
@@ -438,18 +438,24 @@ class PatchEmbedding2D(ParallelLayer):
self
.
weight
=
Parameter
(
torch
.
empty
(
(
self
.
embed_size_per_partition
,
in_chans
,
*
self
.
patch_size
),
device
=
get_current_device
(),
device
=
get_
accelerator
().
get_
current_device
(),
dtype
=
dtype
,
)
)
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
embed_size_per_partition
,
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
embed_size_per_partition
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
self
.
cls_token
=
Parameter
(
torch
.
zeros
((
1
,
1
,
self
.
embed_size_per_partition
),
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
zeros
(
(
1
,
1
,
self
.
embed_size_per_partition
),
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
self
.
pos_embed
=
Parameter
(
torch
.
zeros
(
(
1
,
self
.
num_patches
+
1
,
self
.
embed_size_per_partition
),
device
=
get_current_device
(),
dtype
=
dtype
(
1
,
self
.
num_patches
+
1
,
self
.
embed_size_per_partition
),
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
,
)
)
...
...
@@ -619,7 +625,9 @@ class Embedding2D(ParallelLayer):
self
.
embed_kwargs
=
kwargs
self
.
weight
=
Parameter
(
torch
.
empty
((
num_embeddings
,
embed_dim_per_partition
),
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
empty
(
(
num_embeddings
,
embed_dim_per_partition
),
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
self
.
reset_parameters
(
weight_initializer
)
...
...
@@ -758,7 +766,7 @@ class VocabParallelEmbedding2D(ParallelLayer):
self
.
weight
=
Parameter
(
torch
.
empty
(
(
self
.
num_embeddings_per_partition
,
self
.
embed_dim_per_partition
),
device
=
get_current_device
(),
device
=
get_
accelerator
().
get_
current_device
(),
dtype
=
dtype
,
)
)
...
...
@@ -895,11 +903,18 @@ class Classifier2D(ParallelLayer):
self
.
has_weight
=
False
else
:
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
num_classes
,
self
.
input_size_per_partition
,
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
empty
(
self
.
num_classes
,
self
.
input_size_per_partition
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
,
)
)
self
.
has_weight
=
True
if
bias
:
self
.
bias
=
Parameter
(
torch
.
zeros
(
self
.
num_classes
,
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
bias
=
Parameter
(
torch
.
zeros
(
self
.
num_classes
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
else
:
self
.
bias
=
None
...
...
@@ -1052,7 +1067,7 @@ class VocabParallelClassifier2D(ParallelLayer):
self
.
output_size_per_partition
=
divide
(
num_classes
,
self
.
summa_dim
)
# create weight, shape: [k/q, h/q]
factory_kwargs
=
{
"device"
:
get_current_device
(),
"dtype"
:
dtype
}
factory_kwargs
=
{
"device"
:
get_accelerator
().
get_current_device
(),
"dtype"
:
dtype
}
if
weight
is
not
None
:
self
.
weight
=
weight
self
.
has_weight
=
False
...
...
colossalai/legacy/nn/layer/parallel_2p5d/_operation.py
View file @
8823cc48
...
...
@@ -5,10 +5,10 @@ import torch.distributed as dist
from
torch
import
Tensor
from
torch.cuda.amp
import
custom_bwd
,
custom_fwd
from
colossalai.accelerator
import
get_accelerator
from
colossalai.legacy.communication.collective
import
all_gather
,
all_reduce
,
reduce_scatter
from
colossalai.legacy.context.parallel_mode
import
ParallelMode
from
colossalai.legacy.core
import
global_context
as
gpc
from
colossalai.utils
import
get_current_device
def
get_parallel_group
(
parallel_mode
:
ParallelMode
):
...
...
@@ -205,7 +205,7 @@ class Matmul_AB_2p5D(torch.autograd.Function):
B_shape
=
B
.
shape
B
=
B
.
reshape
((
-
1
,
B_shape
[
-
1
]))
C_shape
=
(
A
.
shape
[
0
],
B
.
shape
[
-
1
])
C
=
torch
.
zeros
(
C_shape
,
dtype
=
A
.
dtype
,
device
=
get_current_device
())
C
=
torch
.
zeros
(
C_shape
,
dtype
=
A
.
dtype
,
device
=
get_
accelerator
().
get_
current_device
())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
...
...
@@ -362,7 +362,7 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
B_shape
=
B
.
shape
B
=
B
.
reshape
((
-
1
,
B_shape
[
-
1
]))
C_shape
=
(
A
.
shape
[
0
],
B
.
shape
[
0
])
C
=
torch
.
empty
(
C_shape
,
dtype
=
A
.
dtype
,
device
=
get_current_device
())
C
=
torch
.
empty
(
C_shape
,
dtype
=
A
.
dtype
,
device
=
get_
accelerator
().
get_
current_device
())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
...
...
@@ -527,7 +527,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
B_shape
=
B
.
shape
B
=
B
.
reshape
((
-
1
,
B_shape
[
-
1
]))
C_shape
=
(
A
.
shape
[
-
1
],
B
.
shape
[
-
1
])
C
=
torch
.
empty
(
C_shape
,
dtype
=
A
.
dtype
,
device
=
get_current_device
())
C
=
torch
.
empty
(
C_shape
,
dtype
=
A
.
dtype
,
device
=
get_
accelerator
().
get_
current_device
())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
...
...
@@ -661,7 +661,9 @@ class _Add_Bias_2p5D(torch.autograd.Function):
if
row_rank
==
0
:
bias_temp
=
bias
.
clone
()
else
:
bias_temp
=
torch
.
zeros
(
output_size_per_partition
,
dtype
=
bias
.
dtype
,
device
=
get_current_device
())
bias_temp
=
torch
.
zeros
(
output_size_per_partition
,
dtype
=
bias
.
dtype
,
device
=
get_accelerator
().
get_current_device
()
)
src_rank
=
(
col_rank
+
dep_rank
*
tesseract_dim
**
2
...
...
@@ -984,7 +986,7 @@ class SplitFirst(torch.autograd.Function):
@
custom_bwd
def
backward
(
ctx
:
Any
,
output_grad
:
Tensor
)
->
Tuple
[
Tensor
,
...]:
grad_shape
=
(
ctx
.
batch_size
,)
+
output_grad
.
shape
[
1
:]
grad
=
torch
.
empty
(
grad_shape
,
dtype
=
output_grad
.
dtype
,
device
=
get_current_device
())
grad
=
torch
.
empty
(
grad_shape
,
dtype
=
output_grad
.
dtype
,
device
=
get_
accelerator
().
get_
current_device
())
dist
.
all_gather
(
list
(
grad
.
chunk
(
ctx
.
tesseract_dim
,
dim
=
0
)),
output_grad
.
contiguous
(),
group
=
gpc
.
get_group
(
ctx
.
para_mode
)
)
...
...
colossalai/legacy/nn/layer/parallel_2p5d/layers.py
View file @
8823cc48
...
...
@@ -8,6 +8,7 @@ import torch.nn.functional as F
from
torch
import
Tensor
from
torch.nn
import
Parameter
from
colossalai.accelerator
import
get_accelerator
from
colossalai.legacy.communication
import
broadcast
from
colossalai.legacy.context
import
ParallelMode
,
seed
from
colossalai.legacy.core
import
global_context
as
gpc
...
...
@@ -19,7 +20,6 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict
,
)
from
colossalai.nn
import
init
as
init
from
colossalai.utils.device
import
get_current_device
from
..base_layer
import
ParallelLayer
from
..utils
import
divide
,
set_tensor_parallel_attribute_by_partition
,
to_2tuple
...
...
@@ -84,7 +84,7 @@ class Linear2p5D(ParallelLayer):
self
.
hidden_size_per_partition
=
divide
(
out_features
,
self
.
tesseract_dim
)
# create weight, shape: [k/q, h/q]
factory_kwargs
=
{
"device"
:
get_current_device
(),
"dtype"
:
dtype
}
factory_kwargs
=
{
"device"
:
get_accelerator
().
get_current_device
(),
"dtype"
:
dtype
}
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
input_size_per_partition
,
self
.
hidden_size_per_partition
,
**
factory_kwargs
)
)
...
...
@@ -272,7 +272,7 @@ class LayerNorm2p5D(ParallelLayer):
self
.
partitioned_partition
=
divide
(
normalized_shape
,
self
.
tesseract_dim
)
# *
# create parameters
factory_kwargs
=
{
"device"
:
get_current_device
(),
"dtype"
:
dtype
}
factory_kwargs
=
{
"device"
:
get_accelerator
().
get_current_device
(),
"dtype"
:
dtype
}
self
.
weight
=
Parameter
(
torch
.
ones
(
self
.
partitioned_partition
,
**
factory_kwargs
))
if
bias
:
...
...
@@ -451,18 +451,24 @@ class PatchEmbedding2p5D(ParallelLayer):
self
.
weight
=
Parameter
(
torch
.
empty
(
(
self
.
embed_size_per_partition
,
in_chans
,
*
self
.
patch_size
),
device
=
get_current_device
(),
device
=
get_
accelerator
().
get_
current_device
(),
dtype
=
dtype
,
)
)
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
embed_size_per_partition
,
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
embed_size_per_partition
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
self
.
cls_token
=
Parameter
(
torch
.
zeros
((
1
,
1
,
self
.
embed_size_per_partition
),
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
zeros
(
(
1
,
1
,
self
.
embed_size_per_partition
),
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
self
.
pos_embed
=
Parameter
(
torch
.
zeros
(
(
1
,
self
.
num_patches
+
1
,
self
.
embed_size_per_partition
),
device
=
get_current_device
(),
dtype
=
dtype
(
1
,
self
.
num_patches
+
1
,
self
.
embed_size_per_partition
),
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
,
)
)
...
...
@@ -632,7 +638,9 @@ class Embedding2p5D(ParallelLayer):
self
.
embed_kwargs
=
kwargs
self
.
weight
=
Parameter
(
torch
.
empty
((
num_embeddings
,
embed_dim_per_partition
),
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
empty
(
(
num_embeddings
,
embed_dim_per_partition
),
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
self
.
reset_parameters
(
weight_initializer
)
...
...
@@ -772,7 +780,7 @@ class VocabParallelEmbedding2p5D(ParallelLayer):
self
.
weight
=
Parameter
(
torch
.
empty
(
(
self
.
num_embeddings_per_partition
,
self
.
embed_dim_per_partition
),
device
=
get_current_device
(),
device
=
get_
accelerator
().
get_
current_device
(),
dtype
=
dtype
,
)
)
...
...
@@ -910,11 +918,18 @@ class Classifier2p5D(ParallelLayer):
self
.
has_weight
=
False
else
:
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
num_classes
,
self
.
input_size_per_partition
,
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
empty
(
self
.
num_classes
,
self
.
input_size_per_partition
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
,
)
)
self
.
has_weight
=
True
if
bias
:
self
.
bias
=
Parameter
(
torch
.
zeros
(
self
.
num_classes
,
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
bias
=
Parameter
(
torch
.
zeros
(
self
.
num_classes
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
else
:
self
.
bias
=
None
...
...
@@ -1068,7 +1083,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
self
.
hidden_size_per_partition
=
divide
(
num_classes
,
self
.
tesseract_dim
)
# create weight, shape: [k/q, h/q]
factory_kwargs
=
{
"device"
:
get_current_device
(),
"dtype"
:
dtype
}
factory_kwargs
=
{
"device"
:
get_accelerator
().
get_current_device
(),
"dtype"
:
dtype
}
if
weight
is
not
None
:
self
.
weight
=
weight
self
.
has_weight
=
False
...
...
colossalai/legacy/nn/layer/parallel_3d/layers.py
View file @
8823cc48
...
...
@@ -8,6 +8,7 @@ import torch.nn.functional as F
from
torch
import
Tensor
from
torch.nn
import
Parameter
from
colossalai.accelerator
import
get_accelerator
from
colossalai.legacy.communication
import
all_reduce
,
broadcast
from
colossalai.legacy.constants
import
(
INPUT_GROUP_3D
,
...
...
@@ -27,7 +28,6 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict
,
)
from
colossalai.nn
import
init
as
init
from
colossalai.utils.device
import
get_current_device
from
..utils
import
divide
,
set_tensor_parallel_attribute_by_partition
,
to_2tuple
from
._operation
import
(
...
...
@@ -69,11 +69,13 @@ class LayerNorm3D(ParallelLayer):
self
.
normalized_shape_per_partition
=
divide
(
normalized_shape
,
self
.
depth
)
self
.
weight
=
Parameter
(
torch
.
ones
(
self
.
normalized_shape_per_partition
,
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
ones
(
self
.
normalized_shape_per_partition
,
device
=
get_
accelerator
().
get_
current_device
(),
dtype
=
dtype
)
)
if
bias
:
self
.
bias
=
Parameter
(
torch
.
zeros
(
self
.
normalized_shape_per_partition
,
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
zeros
(
self
.
normalized_shape_per_partition
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
else
:
self
.
bias
=
None
...
...
@@ -202,13 +204,15 @@ class Linear3D(ParallelLayer):
torch
.
empty
(
self
.
in_features_per_partition
,
self
.
out_features_per_partition
,
device
=
get_current_device
(),
device
=
get_
accelerator
().
get_
current_device
(),
dtype
=
dtype
,
)
)
if
bias
:
self
.
bias
=
Parameter
(
torch
.
zeros
(
self
.
bias_features_per_partition
,
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
zeros
(
self
.
bias_features_per_partition
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
else
:
self
.
bias
=
None
...
...
@@ -380,11 +384,18 @@ class Classifier3D(ParallelLayer):
self
.
has_weight
=
False
else
:
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
num_classes
,
self
.
in_features_per_partition
,
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
empty
(
self
.
num_classes
,
self
.
in_features_per_partition
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
,
)
)
self
.
has_weight
=
True
if
bias
:
self
.
bias
=
Parameter
(
torch
.
zeros
(
self
.
num_classes
,
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
bias
=
Parameter
(
torch
.
zeros
(
self
.
num_classes
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
else
:
self
.
bias
=
None
...
...
@@ -523,14 +534,16 @@ class VocabParallelClassifier3D(ParallelLayer):
torch
.
empty
(
self
.
out_features_per_partition
,
self
.
in_features_per_partition
,
device
=
get_current_device
(),
device
=
get_
accelerator
().
get_
current_device
(),
dtype
=
dtype
,
)
)
self
.
has_weight
=
True
if
bias
:
self
.
bias
=
Parameter
(
torch
.
zeros
(
self
.
bias_features_per_partition
,
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
zeros
(
self
.
bias_features_per_partition
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
else
:
self
.
bias
=
None
...
...
@@ -705,16 +718,24 @@ class PatchEmbedding3D(ParallelLayer):
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
(
embed_size_per_partition
,
in_chans
,
*
self
.
patch_size
),
device
=
get_current_device
(),
dtype
=
dtype
(
embed_size_per_partition
,
in_chans
,
*
self
.
patch_size
),
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
,
)
)
self
.
bias
=
nn
.
Parameter
(
torch
.
empty
(
embed_size_per_partition
,
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
bias
=
nn
.
Parameter
(
torch
.
empty
(
embed_size_per_partition
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
self
.
cls_token
=
nn
.
Parameter
(
torch
.
zeros
((
1
,
1
,
embed_size_per_partition
),
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
zeros
((
1
,
1
,
embed_size_per_partition
),
device
=
get_
accelerator
().
get_
current_device
(),
dtype
=
dtype
)
)
self
.
pos_embed
=
nn
.
Parameter
(
torch
.
zeros
((
1
,
self
.
num_patches
+
1
,
embed_size_per_partition
),
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
zeros
(
(
1
,
self
.
num_patches
+
1
,
embed_size_per_partition
),
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
,
)
)
self
.
reset_parameters
(
weight_initializer
,
bias_initializer
,
position_embed_initializer
)
...
...
@@ -880,7 +901,9 @@ class Embedding3D(ParallelLayer):
self
.
embed_kwargs
=
kwargs
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
((
num_embeddings
,
embed_dim_per_partition
),
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
empty
(
(
num_embeddings
,
embed_dim_per_partition
),
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
self
.
reset_parameters
(
weight_initializer
)
...
...
@@ -1019,7 +1042,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
self
.
weight
=
Parameter
(
torch
.
empty
(
(
self
.
num_embeddings_per_partition
,
self
.
embed_dim_per_partition
),
device
=
get_current_device
(),
device
=
get_
accelerator
().
get_
current_device
(),
dtype
=
dtype
,
)
)
...
...
colossalai/legacy/nn/layer/parallel_sequence/_operation.py
View file @
8823cc48
...
...
@@ -5,11 +5,11 @@ import torch
from
torch
import
distributed
as
dist
from
torch.cuda.amp
import
custom_bwd
,
custom_fwd
from
colossalai.accelerator
import
get_accelerator
from
colossalai.legacy.communication
import
ring_forward
from
colossalai.legacy.context.parallel_mode
import
ParallelMode
from
colossalai.legacy.core
import
global_context
as
gpc
from
colossalai.legacy.nn.layer.parallel_sequence._utils
import
_calc_current_device_range
,
_calc_incoming_device_range
from
colossalai.utils
import
get_current_device
class
RingQK
(
torch
.
autograd
.
Function
):
...
...
@@ -30,7 +30,7 @@ class RingQK(torch.autograd.Function):
sub_seq_length
,
sub_seq_length
*
gpc
.
get_world_size
(
ParallelMode
.
SEQUENCE
),
dtype
=
sub_q
.
dtype
,
device
=
get_current_device
(),
device
=
get_
accelerator
().
get_
current_device
(),
)
# compute local QK^T
...
...
@@ -71,7 +71,7 @@ class RingQK(torch.autograd.Function):
grad_q
=
torch
.
zeros_like
(
sub_q
,
dtype
=
sub_q
.
dtype
,
device
=
get_current_device
(),
device
=
get_
accelerator
().
get_
current_device
(),
)
# compute with local sub_k
...
...
@@ -105,7 +105,7 @@ class RingAV(torch.autograd.Function):
batch_size
*
num_attention_heads
,
sub_seq_length
,
attention_head_size
,
device
=
get_current_device
(),
device
=
get_
accelerator
().
get_
current_device
(),
dtype
=
attention_score
.
dtype
,
)
...
...
@@ -142,7 +142,9 @@ class RingAV(torch.autograd.Function):
grad_v
/=
local_world_size
# calculate gradient for attention score
grad_attention_score
=
torch
.
zeros_like
(
attention_scores
,
dtype
=
grad_output
.
dtype
,
device
=
get_current_device
())
grad_attention_score
=
torch
.
zeros_like
(
attention_scores
,
dtype
=
grad_output
.
dtype
,
device
=
get_accelerator
().
get_current_device
()
)
# compute with local sub_k
grad_attention_score
[:,
:,
local_start_idx
:
local_end_idx
]
+=
torch
.
matmul
(
grad_output
,
sub_v
.
transpose
(
2
,
1
))
...
...
colossalai/legacy/nn/layer/parallel_sequence/layers.py
View file @
8823cc48
...
...
@@ -8,13 +8,12 @@ import torch.nn as nn
import
torch.nn.functional
as
F
from
torch.nn
import
Parameter
from
colossalai.kernel
import
FusedScaleMaskSoftmax
from
colossalai.kernel.cuda_native.scaled_softmax
import
AttnMaskType
from
colossalai.legacy.context
import
seed
from
colossalai.legacy.context.parallel_mode
import
ParallelMode
from
colossalai.legacy.core
import
global_context
as
gpc
from
colossalai.legacy.nn.layer.parallel_sequence._operation
import
RingAV
,
RingQK
from
colossalai.legacy.registry
import
LAYERS
from
colossalai.nn.layer.scaled_softmax
import
AttnMaskType
,
FusedScaleMaskSoftmax
@
LAYERS
.
register_module
...
...
colossalai/legacy/nn/layer/vanilla/layers.py
View file @
8823cc48
...
...
@@ -7,10 +7,10 @@ from torch import Tensor
from
torch
import
nn
as
nn
from
torch.nn.parameter
import
Parameter
from
colossalai.accelerator
import
get_accelerator
from
colossalai.legacy.context
import
seed
from
colossalai.legacy.registry
import
LAYERS
from
colossalai.nn
import
init
as
init
from
colossalai.utils.device
import
get_current_device
from
..utils
import
to_2tuple
...
...
@@ -173,12 +173,18 @@ class VanillaPatchEmbedding(nn.Module):
self
.
flatten
=
flatten
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
((
embed_size
,
in_chans
,
*
self
.
patch_size
),
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
empty
(
(
embed_size
,
in_chans
,
*
self
.
patch_size
),
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
self
.
bias
=
nn
.
Parameter
(
torch
.
empty
(
embed_size
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
))
self
.
cls_token
=
nn
.
Parameter
(
torch
.
zeros
((
1
,
1
,
embed_size
),
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
self
.
bias
=
nn
.
Parameter
(
torch
.
empty
(
embed_size
,
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
cls_token
=
nn
.
Parameter
(
torch
.
zeros
((
1
,
1
,
embed_size
),
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
pos_embed
=
nn
.
Parameter
(
torch
.
zeros
((
1
,
self
.
num_patches
+
1
,
embed_size
),
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
zeros
(
(
1
,
self
.
num_patches
+
1
,
embed_size
),
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
self
.
reset_parameters
(
weight_initializer
,
bias_initializer
,
position_embed_initializer
)
...
...
@@ -242,11 +248,15 @@ class VanillaClassifier(nn.Module):
self
.
has_weight
=
False
else
:
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_classes
,
self
.
in_features
,
device
=
get_current_device
(),
dtype
=
dtype
)
torch
.
empty
(
self
.
num_classes
,
self
.
in_features
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
self
.
has_weight
=
True
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_classes
,
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_classes
,
device
=
get_accelerator
().
get_current_device
(),
dtype
=
dtype
)
)
else
:
self
.
bias
=
None
...
...
@@ -287,7 +297,7 @@ class VanillaLayerNorm(nn.Module):
self
.
normalized_shape
=
(
normalized_shape
,)
self
.
variance_epsilon
=
eps
factory_kwargs
=
{
"device"
:
get_current_device
(),
"dtype"
:
dtype
}
factory_kwargs
=
{
"device"
:
get_accelerator
().
get_current_device
(),
"dtype"
:
dtype
}
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
normalized_shape
,
**
factory_kwargs
))
if
bias
:
...
...
@@ -333,7 +343,7 @@ class VanillaLinear(nn.Module):
self
.
in_features
=
in_features
self
.
out_features
=
out_features
self
.
skip_bias_add
=
skip_bias_add
factory_kwargs
=
{
"device"
:
get_current_device
(),
"dtype"
:
dtype
}
factory_kwargs
=
{
"device"
:
get_accelerator
().
get_current_device
(),
"dtype"
:
dtype
}
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
out_features
,
self
.
in_features
,
**
factory_kwargs
))
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
out_features
,
**
factory_kwargs
))
...
...
colossalai/legacy/nn/loss/loss_2d.py
View file @
8823cc48
...
...
@@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
from
torch.nn.functional
import
cross_entropy
from
torch.nn.modules.loss
import
_Loss
from
colossalai.accelerator
import
get_accelerator
from
colossalai.legacy.context
import
ParallelMode
from
colossalai.legacy.core
import
global_context
as
gpc
from
colossalai.legacy.nn.layer.parallel_2d
import
reduce_by_batch_2d
,
split_batch_2d
from
colossalai.legacy.nn.layer.parallel_2d._utils
import
assert_summa_initialization
from
colossalai.legacy.registry
import
LOSSES
from
colossalai.utils
import
get_current_device
@
LOSSES
.
register_module
...
...
@@ -118,7 +118,7 @@ class _VocabParallelCrossEntropy2D(torch.autograd.Function):
grad_2d
=
grad_input
.
view
(
-
1
,
partition_vocab_size
)
# Add the gradient from matching classes.
arange_1d
=
torch
.
arange
(
start
=
0
,
end
=
grad_2d
.
size
()[
0
],
device
=
get_current_device
())
arange_1d
=
torch
.
arange
(
start
=
0
,
end
=
grad_2d
.
size
()[
0
],
device
=
get_
accelerator
().
get_
current_device
())
grad_2d
[
arange_1d
,
masked_target
]
-=
1.0
-
target_mask
.
view
(
-
1
).
float
()
# Finally elementwise multiplication with the output gradients.
...
...
colossalai/legacy/nn/loss/loss_2p5d.py
View file @
8823cc48
...
...
@@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
from
torch.nn.functional
import
cross_entropy
from
torch.nn.modules.loss
import
_Loss
from
colossalai.accelerator
import
get_accelerator
from
colossalai.legacy.context
import
ParallelMode
from
colossalai.legacy.core
import
global_context
as
gpc
from
colossalai.legacy.nn.layer.parallel_2p5d
import
reduce_by_batch_2p5d
,
split_batch_2p5d
from
colossalai.legacy.nn.layer.parallel_2p5d._utils
import
assert_tesseract_initialization
from
colossalai.legacy.registry
import
LOSSES
from
colossalai.utils
import
get_current_device
@
LOSSES
.
register_module
...
...
@@ -112,7 +112,7 @@ class _VocabParallelCrossEntropy2p5D(torch.autograd.Function):
grad_2d
=
grad_input
.
view
(
-
1
,
partition_vocab_size
)
# Add the gradient from matching classes.
arange_1d
=
torch
.
arange
(
start
=
0
,
end
=
grad_2d
.
size
()[
0
],
device
=
get_current_device
())
arange_1d
=
torch
.
arange
(
start
=
0
,
end
=
grad_2d
.
size
()[
0
],
device
=
get_
accelerator
().
get_
current_device
())
grad_2d
[
arange_1d
,
masked_target
]
-=
1.0
-
target_mask
.
view
(
-
1
).
float
()
# Finally elementwise multiplication with the output gradients.
...
...
Prev
1
2
3
4
5
6
7
8
9
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment