Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c2947dad
Unverified
Commit
c2947dad
authored
Nov 10, 2022
by
Jiarui Fang
Committed by
GitHub
Nov 10, 2022
Browse files
[inference] streaming Linear 1D Row inference (#1874)
parent
a1416812
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
629 additions
and
554 deletions
+629
-554
colossalai/nn/layer/parallel_1d/layers.py
colossalai/nn/layer/parallel_1d/layers.py
+21
-5
tests/test_fx/test_complete_workflow.py
tests/test_fx/test_complete_workflow.py
+10
-7
tests/test_layers/test_1d/checks_1d/check_layer_1d.py
tests/test_layers/test_1d/checks_1d/check_layer_1d.py
+549
-496
tests/test_layers/test_1d/test_1d.py
tests/test_layers/test_1d/test_1d.py
+49
-46
No files found.
colossalai/nn/layer/parallel_1d/layers.py
View file @
c2947dad
...
@@ -597,9 +597,12 @@ class Linear1D_Row(ParallelLayer):
...
@@ -597,9 +597,12 @@ class Linear1D_Row(ParallelLayer):
parallel_input
:
bool
=
True
,
parallel_input
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
weight_initializer
:
Callable
=
init
.
kaiming_uniform_
(
a
=
math
.
sqrt
(
5
)),
weight_initializer
:
Callable
=
init
.
kaiming_uniform_
(
a
=
math
.
sqrt
(
5
)),
bias_initializer
:
Callable
=
init
.
xavier_uniform_
(
a
=
1
,
scale
=
1
)):
bias_initializer
:
Callable
=
init
.
xavier_uniform_
(
a
=
1
,
scale
=
1
),
stream_chunk_num
:
int
=
1
):
super
().
__init__
()
super
().
__init__
()
self
.
stream_chunk_num
=
stream_chunk_num
# Keep input parameters
# Keep input parameters
self
.
in_features
=
in_features
self
.
in_features
=
in_features
self
.
out_features
=
out_features
self
.
out_features
=
out_features
...
@@ -617,6 +620,9 @@ class Linear1D_Row(ParallelLayer):
...
@@ -617,6 +620,9 @@ class Linear1D_Row(ParallelLayer):
factory_kwargs
=
{
'device'
:
get_current_device
(),
'dtype'
:
dtype
}
factory_kwargs
=
{
'device'
:
get_current_device
(),
'dtype'
:
dtype
}
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
out_features
,
self
.
input_size_per_partition
,
**
factory_kwargs
))
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
out_features
,
self
.
input_size_per_partition
,
**
factory_kwargs
))
if
self
.
stream_chunk_num
>
1
:
# TODO() work for inference only
self
.
chunk_weight
()
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
out_features
,
**
factory_kwargs
))
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
out_features
,
**
factory_kwargs
))
else
:
else
:
...
@@ -626,6 +632,9 @@ class Linear1D_Row(ParallelLayer):
...
@@ -626,6 +632,9 @@ class Linear1D_Row(ParallelLayer):
self
.
_set_tensor_parallel_attributes
()
self
.
_set_tensor_parallel_attributes
()
set_parallel_input
(
False
)
set_parallel_input
(
False
)
def
chunk_weight
(
self
):
self
.
weight_list
=
torch
.
chunk
(
self
.
weight
,
self
.
stream_chunk_num
,
dim
=
0
)
def
reset_parameters
(
self
,
weight_initializer
,
bias_initializer
)
->
None
:
def
reset_parameters
(
self
,
weight_initializer
,
bias_initializer
)
->
None
:
fan_in
,
fan_out
=
self
.
in_features
,
self
.
out_features
fan_in
,
fan_out
=
self
.
in_features
,
self
.
out_features
weight_initializer
(
self
.
weight
,
fan_in
=
fan_in
,
fan_out
=
fan_out
)
weight_initializer
(
self
.
weight
,
fan_in
=
fan_in
,
fan_out
=
fan_out
)
...
@@ -696,10 +705,17 @@ class Linear1D_Row(ParallelLayer):
...
@@ -696,10 +705,17 @@ class Linear1D_Row(ParallelLayer):
input_
.
shape
,
self
.
weight
.
shape
,
self
.
weight
.
shape
[
-
1
]
*
gpc
.
tensor_parallel_size
)
input_
.
shape
,
self
.
weight
.
shape
,
self
.
weight
.
shape
[
-
1
]
*
gpc
.
tensor_parallel_size
)
input_
=
split_forward_gather_backward
(
input_
,
ParallelMode
.
PARALLEL_1D
,
dim
=-
1
)
input_
=
split_forward_gather_backward
(
input_
,
ParallelMode
.
PARALLEL_1D
,
dim
=-
1
)
if
self
.
stream_chunk_num
>
1
:
output_parallel_list
=
[
None
for
i
in
range
(
self
.
stream_chunk_num
)]
for
i
in
range
(
self
.
stream_chunk_num
):
output_parallel_list
[
i
]
=
F
.
linear
(
input_
,
self
.
weight_list
[
i
])
output_parallel_list
[
i
]
=
reduce_input
(
output_parallel_list
[
i
],
ParallelMode
.
PARALLEL_1D
)
output
=
torch
.
cat
(
output_parallel_list
,
dim
=-
1
)
else
:
print
(
input_
.
shape
,
self
.
weight
.
shape
)
output_parallel
=
F
.
linear
(
input_
,
self
.
weight
)
output_parallel
=
F
.
linear
(
input_
,
self
.
weight
)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output
=
reduce_input
(
output_parallel
,
ParallelMode
.
PARALLEL_1D
)
output
=
reduce_input
(
output_parallel
,
ParallelMode
.
PARALLEL_1D
)
if
not
self
.
skip_bias_add
:
if
not
self
.
skip_bias_add
:
if
self
.
bias
is
not
None
:
if
self
.
bias
is
not
None
:
output
=
output
+
self
.
bias
output
=
output
+
self
.
bias
...
...
tests/test_fx/test_complete_workflow.py
View file @
c2947dad
...
@@ -32,7 +32,7 @@ class MLP(torch.nn.Module):
...
@@ -32,7 +32,7 @@ class MLP(torch.nn.Module):
return
x
return
x
def
run_workflow
(
world_size
):
def
run_workflow
(
world_size
,
dev
):
# initailization
# initailization
with
LazyInitContext
()
as
ctx
:
with
LazyInitContext
()
as
ctx
:
model
=
MLP
(
16
)
model
=
MLP
(
16
)
...
@@ -46,7 +46,7 @@ def run_workflow(world_size):
...
@@ -46,7 +46,7 @@ def run_workflow(world_size):
gm
=
torch
.
fx
.
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
=
torch
.
fx
.
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
# annotate
# annotate
annotated_gm
=
transformer_mlp_pass
(
gm
,
process_group
=
ProcessGroup
())
annotated_gm
=
transformer_mlp_pass
(
gm
,
process_group
=
ProcessGroup
(
tp_degree
=
world_size
))
annotated_gm
.
recompile
()
annotated_gm
.
recompile
()
# materialization and sharding
# materialization and sharding
...
@@ -61,22 +61,25 @@ def run_workflow(world_size):
...
@@ -61,22 +61,25 @@ def run_workflow(world_size):
# test forward to make sure that IR transform will produce the same results
# test forward to make sure that IR transform will produce the same results
# like how ColoTensor would do it normally
# like how ColoTensor would do it normally
data
=
torch
.
rand
(
4
,
16
)
data
=
torch
.
rand
(
4
,
16
,
device
=
dev
)
non_fx_out
=
model
(
data
)
non_fx_out
=
model
(
data
)
fx_out
=
annotated_gm
(
data
)
fx_out
=
annotated_gm
(
data
)
assert
torch
.
equal
(
non_fx_out
,
fx_out
),
f
'
{
non_fx_out
}
vs
{
fx_out
}
'
assert
torch
.
equal
(
non_fx_out
,
fx_out
),
f
'
{
non_fx_out
}
vs
{
fx_out
}
'
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
dev
,
port
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_workflow
(
world_size
)
run_workflow
(
world_size
,
dev
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
'dev'
,
[
'cuda'
,
'cpu'
])
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
def
test_complete_workflow
(
world_size
):
def
test_complete_workflow
(
world_size
,
dev
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
if
dev
==
'cpu'
and
world_size
>
1
:
return
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
dev
=
dev
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
...
...
tests/test_layers/test_1d/checks_1d/check_layer_1d.py
View file @
c2947dad
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch.nn
import
Parameter
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.global_variables
import
tensor_parallel_env
as
env
from
colossalai.global_variables
import
tensor_parallel_env
as
env
from
colossalai.nn
import
(
Classifier1D
,
Embedding1D
,
Linear1D_Col
,
Linear1D_Row
,
VanillaClassifier
,
from
colossalai.nn
import
(
VocabParallelClassifier1D
,
VocabParallelCrossEntropyLoss1D
,
VocabParallelEmbedding1D
)
Classifier1D
,
Embedding1D
,
Linear1D_Col
,
Linear1D_Row
,
VanillaClassifier
,
VocabParallelClassifier1D
,
VocabParallelCrossEntropyLoss1D
,
VocabParallelEmbedding1D
,
)
from
colossalai.utils
import
get_current_device
,
print_rank_0
from
colossalai.utils
import
get_current_device
,
print_rank_0
from
torch.nn
import
Parameter
from
.common
import
BATCH_SIZE
,
DEPTH
,
HIDDEN_SIZE
,
NUM_CLASSES
,
SEQ_LENGTH
,
VOCAB_SIZE
,
check_equal
from
.common
import
BATCH_SIZE
,
DEPTH
,
HIDDEN_SIZE
,
NUM_CLASSES
,
SEQ_LENGTH
,
VOCAB_SIZE
,
check_equal
...
@@ -494,3 +503,47 @@ def check_vocab_parallel_loss():
...
@@ -494,3 +503,47 @@ def check_vocab_parallel_loss():
out_grad
=
torch
.
chunk
(
out_grad
,
DEPTH
,
dim
=-
1
)[
i
]
out_grad
=
torch
.
chunk
(
out_grad
,
DEPTH
,
dim
=-
1
)[
i
]
check_equal
(
out_grad
,
out
.
grad
)
check_equal
(
out_grad
,
out
.
grad
)
print_rank_0
(
'vocab parallel loss backward: pass'
)
print_rank_0
(
'vocab parallel loss backward: pass'
)
@
torch
.
no_grad
()
def
check_linear_row_stream_inference
():
device
=
get_current_device
()
dtype
=
torch
.
float32
INPUT_SIZE
=
HIDDEN_SIZE
OUTPUT_SIZE
=
2
*
HIDDEN_SIZE
i
=
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_1D
)
assert
HIDDEN_SIZE
%
2
==
0
layer
=
Linear1D_Row
(
OUTPUT_SIZE
,
INPUT_SIZE
,
stream_chunk_num
=
2
)
A_shape
=
(
BATCH_SIZE
,
SEQ_LENGTH
,
OUTPUT_SIZE
)
A_master
=
torch
.
randn
(
A_shape
,
dtype
=
dtype
,
device
=
device
)
dist
.
broadcast
(
A_master
,
src
=
0
)
A
=
torch
.
chunk
(
A_master
,
DEPTH
,
dim
=-
1
)[
i
]
A
=
A
.
clone
()
W_shape
=
(
INPUT_SIZE
,
OUTPUT_SIZE
)
W_master
=
torch
.
randn
(
W_shape
,
dtype
=
dtype
,
device
=
device
)
dist
.
broadcast
(
W_master
,
src
=
0
)
W
=
torch
.
chunk
(
W_master
,
DEPTH
,
dim
=-
1
)[
i
]
W
=
W
.
clone
()
B_shape
=
(
INPUT_SIZE
)
B_master
=
torch
.
randn
(
B_shape
,
dtype
=
dtype
,
device
=
device
)
dist
.
broadcast
(
B_master
,
src
=
0
)
B
=
B_master
.
clone
()
layer
.
weight
=
Parameter
(
W
)
layer
.
bias
=
Parameter
(
B
)
layer
.
chunk_weight
()
out
=
layer
(
A
)
A_master
=
A_master
.
clone
()
W_master
=
W_master
.
clone
()
B_master
=
B_master
.
clone
()
C_master
=
torch
.
matmul
(
A_master
,
W_master
.
transpose
(
0
,
1
))
+
B_master
C
=
C_master
.
clone
()
check_equal
(
out
,
C
)
print_rank_0
(
'linear_row forward: pass'
)
tests/test_layers/test_1d/test_1d.py
View file @
c2947dad
...
@@ -6,12 +6,13 @@ from functools import partial
...
@@ -6,12 +6,13 @@ from functools import partial
import
pytest
import
pytest
import
torch
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
checks_1d.check_layer_1d
import
*
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.initialize
import
launch
from
colossalai.initialize
import
launch
from
colossalai.
utils
import
free_port
from
colossalai.
logging
import
disable_existing_loggers
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
c
hecks_1d.check_layer_1d
import
*
from
c
olossalai.utils
import
free_port
CONFIG
=
dict
(
parallel
=
dict
(
pipeline
=
dict
(
size
=
1
),
tensor
=
dict
(
size
=
4
,
mode
=
'1d'
)),)
CONFIG
=
dict
(
parallel
=
dict
(
pipeline
=
dict
(
size
=
1
),
tensor
=
dict
(
size
=
4
,
mode
=
'1d'
)),)
...
@@ -30,6 +31,8 @@ def check_layer(rank, world_size, port):
...
@@ -30,6 +31,8 @@ def check_layer(rank, world_size, port):
check_vocab_parallel_classifier_given_embed_weight
()
check_vocab_parallel_classifier_given_embed_weight
()
check_vocab_parallel_loss
()
check_vocab_parallel_loss
()
check_linear_row_stream_inference
()
gpc
.
destroy
()
gpc
.
destroy
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment