Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c2947dad
Unverified
Commit
c2947dad
authored
Nov 10, 2022
by
Jiarui Fang
Committed by
GitHub
Nov 10, 2022
Browse files
[inference] streaming Linear 1D Row inference (#1874)
parent
a1416812
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
629 additions
and
554 deletions
+629
-554
colossalai/nn/layer/parallel_1d/layers.py
colossalai/nn/layer/parallel_1d/layers.py
+21
-5
tests/test_fx/test_complete_workflow.py
tests/test_fx/test_complete_workflow.py
+10
-7
tests/test_layers/test_1d/checks_1d/check_layer_1d.py
tests/test_layers/test_1d/checks_1d/check_layer_1d.py
+549
-496
tests/test_layers/test_1d/test_1d.py
tests/test_layers/test_1d/test_1d.py
+49
-46
No files found.
colossalai/nn/layer/parallel_1d/layers.py
View file @
c2947dad
...
@@ -597,9 +597,12 @@ class Linear1D_Row(ParallelLayer):
...
@@ -597,9 +597,12 @@ class Linear1D_Row(ParallelLayer):
parallel_input
:
bool
=
True
,
parallel_input
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
weight_initializer
:
Callable
=
init
.
kaiming_uniform_
(
a
=
math
.
sqrt
(
5
)),
weight_initializer
:
Callable
=
init
.
kaiming_uniform_
(
a
=
math
.
sqrt
(
5
)),
bias_initializer
:
Callable
=
init
.
xavier_uniform_
(
a
=
1
,
scale
=
1
)):
bias_initializer
:
Callable
=
init
.
xavier_uniform_
(
a
=
1
,
scale
=
1
),
stream_chunk_num
:
int
=
1
):
super
().
__init__
()
super
().
__init__
()
self
.
stream_chunk_num
=
stream_chunk_num
# Keep input parameters
# Keep input parameters
self
.
in_features
=
in_features
self
.
in_features
=
in_features
self
.
out_features
=
out_features
self
.
out_features
=
out_features
...
@@ -617,6 +620,9 @@ class Linear1D_Row(ParallelLayer):
...
@@ -617,6 +620,9 @@ class Linear1D_Row(ParallelLayer):
factory_kwargs
=
{
'device'
:
get_current_device
(),
'dtype'
:
dtype
}
factory_kwargs
=
{
'device'
:
get_current_device
(),
'dtype'
:
dtype
}
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
out_features
,
self
.
input_size_per_partition
,
**
factory_kwargs
))
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
out_features
,
self
.
input_size_per_partition
,
**
factory_kwargs
))
if
self
.
stream_chunk_num
>
1
:
# TODO() work for inference only
self
.
chunk_weight
()
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
out_features
,
**
factory_kwargs
))
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
out_features
,
**
factory_kwargs
))
else
:
else
:
...
@@ -626,6 +632,9 @@ class Linear1D_Row(ParallelLayer):
...
@@ -626,6 +632,9 @@ class Linear1D_Row(ParallelLayer):
self
.
_set_tensor_parallel_attributes
()
self
.
_set_tensor_parallel_attributes
()
set_parallel_input
(
False
)
set_parallel_input
(
False
)
def chunk_weight(self):
    """Split ``self.weight`` along dim 0 into ``self.stream_chunk_num`` pieces.

    The pieces are stored, in their original order, on ``self.weight_list``,
    so concatenating them along dim 0 reproduces ``self.weight``.

    ``torch.tensor_split`` is used rather than ``torch.chunk`` because
    ``torch.chunk`` may return fewer pieces than requested when the size of
    dim 0 is not evenly divisible by the chunk count. Code that indexes
    ``self.weight_list[i]`` for every ``i`` in ``range(self.stream_chunk_num)``
    would then raise ``IndexError``. ``tensor_split`` always yields exactly
    the requested number of pieces (some may be smaller, or empty when there
    are more pieces than rows).
    """
    self.weight_list = torch.tensor_split(self.weight, self.stream_chunk_num, dim=0)
def
reset_parameters
(
self
,
weight_initializer
,
bias_initializer
)
->
None
:
def
reset_parameters
(
self
,
weight_initializer
,
bias_initializer
)
->
None
:
fan_in
,
fan_out
=
self
.
in_features
,
self
.
out_features
fan_in
,
fan_out
=
self
.
in_features
,
self
.
out_features
weight_initializer
(
self
.
weight
,
fan_in
=
fan_in
,
fan_out
=
fan_out
)
weight_initializer
(
self
.
weight
,
fan_in
=
fan_in
,
fan_out
=
fan_out
)
...
@@ -696,10 +705,17 @@ class Linear1D_Row(ParallelLayer):
...
@@ -696,10 +705,17 @@ class Linear1D_Row(ParallelLayer):
input_
.
shape
,
self
.
weight
.
shape
,
self
.
weight
.
shape
[
-
1
]
*
gpc
.
tensor_parallel_size
)
input_
.
shape
,
self
.
weight
.
shape
,
self
.
weight
.
shape
[
-
1
]
*
gpc
.
tensor_parallel_size
)
input_
=
split_forward_gather_backward
(
input_
,
ParallelMode
.
PARALLEL_1D
,
dim
=-
1
)
input_
=
split_forward_gather_backward
(
input_
,
ParallelMode
.
PARALLEL_1D
,
dim
=-
1
)
output_parallel
=
F
.
linear
(
input_
,
self
.
weight
)
if
self
.
stream_chunk_num
>
1
:
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output_parallel_list
=
[
None
for
i
in
range
(
self
.
stream_chunk_num
)]
output
=
reduce_input
(
output_parallel
,
ParallelMode
.
PARALLEL_1D
)
for
i
in
range
(
self
.
stream_chunk_num
):
output_parallel_list
[
i
]
=
F
.
linear
(
input_
,
self
.
weight_list
[
i
])
output_parallel_list
[
i
]
=
reduce_input
(
output_parallel_list
[
i
],
ParallelMode
.
PARALLEL_1D
)
output
=
torch
.
cat
(
output_parallel_list
,
dim
=-
1
)
else
:
print
(
input_
.
shape
,
self
.
weight
.
shape
)
output_parallel
=
F
.
linear
(
input_
,
self
.
weight
)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output
=
reduce_input
(
output_parallel
,
ParallelMode
.
PARALLEL_1D
)
if
not
self
.
skip_bias_add
:
if
not
self
.
skip_bias_add
:
if
self
.
bias
is
not
None
:
if
self
.
bias
is
not
None
:
output
=
output
+
self
.
bias
output
=
output
+
self
.
bias
...
...
tests/test_fx/test_complete_workflow.py
View file @
c2947dad
...
@@ -32,7 +32,7 @@ class MLP(torch.nn.Module):
...
@@ -32,7 +32,7 @@ class MLP(torch.nn.Module):
return
x
return
x
def
run_workflow
(
world_size
):
def
run_workflow
(
world_size
,
dev
):
# initialization
# initialization
with
LazyInitContext
()
as
ctx
:
with
LazyInitContext
()
as
ctx
:
model
=
MLP
(
16
)
model
=
MLP
(
16
)
...
@@ -46,7 +46,7 @@ def run_workflow(world_size):
...
@@ -46,7 +46,7 @@ def run_workflow(world_size):
gm
=
torch
.
fx
.
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
=
torch
.
fx
.
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
# annotate
# annotate
annotated_gm
=
transformer_mlp_pass
(
gm
,
process_group
=
ProcessGroup
())
annotated_gm
=
transformer_mlp_pass
(
gm
,
process_group
=
ProcessGroup
(
tp_degree
=
world_size
))
annotated_gm
.
recompile
()
annotated_gm
.
recompile
()
# materialization and sharding
# materialization and sharding
...
@@ -61,22 +61,25 @@ def run_workflow(world_size):
...
@@ -61,22 +61,25 @@ def run_workflow(world_size):
# test forward to make sure that IR transform will produce the same results
# test forward to make sure that IR transform will produce the same results
# like how ColoTensor would do it normally
# like how ColoTensor would do it normally
data
=
torch
.
rand
(
4
,
16
)
data
=
torch
.
rand
(
4
,
16
,
device
=
dev
)
non_fx_out
=
model
(
data
)
non_fx_out
=
model
(
data
)
fx_out
=
annotated_gm
(
data
)
fx_out
=
annotated_gm
(
data
)
assert
torch
.
equal
(
non_fx_out
,
fx_out
),
f
'
{
non_fx_out
}
vs
{
fx_out
}
'
assert
torch
.
equal
(
non_fx_out
,
fx_out
),
f
'
{
non_fx_out
}
vs
{
fx_out
}
'
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
dev
,
port
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_workflow
(
world_size
)
run_workflow
(
world_size
,
dev
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
'dev'
,
[
'cuda'
,
'cpu'
])
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
def
test_complete_workflow
(
world_size
):
def
test_complete_workflow
(
world_size
,
dev
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
if
dev
==
'cpu'
and
world_size
>
1
:
return
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
dev
=
dev
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
...
...
tests/test_layers/test_1d/checks_1d/check_layer_1d.py
View file @
c2947dad
This diff is collapsed.
Click to expand it.
tests/test_layers/test_1d/test_1d.py
View file @
c2947dad
#!/usr/bin/env python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# -*- encoding: utf-8 -*-
from
functools
import
partial
from
functools
import
partial
import
pytest
import
pytest
import
torch
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
colossalai.core
import
global_context
as
gpc
from
checks_1d.check_layer_1d
import
*
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.initialize
import
launch
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
free_port
from
colossalai.initialize
import
launch
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.logging
import
disable_existing_loggers
from
checks_1d.check_layer_1d
import
*
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
CONFIG
=
dict
(
parallel
=
dict
(
pipeline
=
dict
(
size
=
1
),
tensor
=
dict
(
size
=
4
,
mode
=
'1d'
)),)
CONFIG
=
dict
(
parallel
=
dict
(
pipeline
=
dict
(
size
=
1
),
tensor
=
dict
(
size
=
4
,
mode
=
'1d'
)),)
def
check_layer
(
rank
,
world_size
,
port
):
disable_existing_loggers
()
def
check_layer
(
rank
,
world_size
,
port
):
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
disable_existing_loggers
()
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
check_linear_col
()
check_linear_row
()
check_linear_col
()
check_embed
()
check_linear_row
()
check_vocab_parallel_embed
()
check_embed
()
check_classifier_no_given_weight
()
check_vocab_parallel_embed
()
check_vocab_parallel_classifier_no_given_weight
()
check_classifier_no_given_weight
()
check_classifier_given_embed_weight
()
check_vocab_parallel_classifier_no_given_weight
()
check_vocab_parallel_classifier_given_embed_weight
()
check_classifier_given_embed_weight
()
check_vocab_parallel_loss
()
check_vocab_parallel_classifier_given_embed_weight
()
check_vocab_parallel_loss
()
gpc
.
destroy
()
torch
.
cuda
.
empty_cache
()
check_linear_row_stream_inference
()
gpc
.
destroy
()
@
pytest
.
mark
.
dist
torch
.
cuda
.
empty_cache
()
@
rerun_if_address_is_in_use
()
def
test_1d
():
world_size
=
4
@
pytest
.
mark
.
dist
run_func
=
partial
(
check_layer
,
world_size
=
world_size
,
port
=
free_port
())
@
rerun_if_address_is_in_use
()
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
def
test_1d
():
world_size
=
4
run_func
=
partial
(
check_layer
,
world_size
=
world_size
,
port
=
free_port
())
if
__name__
==
'__main__'
:
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
test_1d
()
if
__name__
==
'__main__'
:
test_1d
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment