OpenDAS / ColossalAI · Commit 726541af
Authored Aug 01, 2023 by FoolPlayer; committed by Hongxin Liu, Aug 15, 2023
update some module with new api version
Parent: 879301d0
Showing 7 changed files with 89 additions and 49 deletions.
Files changed:
  colossalai/shardformer/layer/qkv_fused_linear.py                    +53 -34
  colossalai/shardformer/policies/blip2.py                             +1  -1
  colossalai/shardformer/policies/chatglm.py                           +1  -1
  colossalai/shardformer/policies/sam.py                               +1  -1
  colossalai/shardformer/policies/whisper.py                           +1  -1
  tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py  +29  -9
  tests/test_shardformer/test_model/test_shard_chatglm.py              +3  -2
colossalai/shardformer/layer/qkv_fused_linear.py

@@ -537,10 +537,11 @@ class FusedLinear1D_Col(ParallelModule):
                  gather_output: bool = False,
                  skip_bias_add: bool = False,
                  n_fused: int = 3,
+                 weight: Optional[Parameter] = None,
+                 bias_: Optional[Parameter] = None,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
         super().__init__()

         # Keep input parameters
         self.in_features = in_features
         self.out_features = out_features
@@ -554,34 +555,50 @@ class FusedLinear1D_Col(ParallelModule):
         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')

-        # offset the seed with randomizer index and rank
-        seed = torch.random.initial_seed()
-        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+        # sanity check
+        if weight is not None:
+            assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
+        else:
+            assert bias_ is None, 'bias_ must be None if weight is None'

         # Parameters.
-        # Initialize weight.
-        factory_kwargs = {'device': device, 'dtype': dtype}
-        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
+        if weight is None:
+            # Initialize weight.
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+            self.weight = weight

         def shard_fn(tensor):
             return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)

         def gather_fn(tensor):
-            return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, False)
+            return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)

-        with torch.no_grad():
-            sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
-        self.weight = customized_distributed_tensor_to_param(sharded_weight)
+        if not is_customized_distributed_tensor(self.weight):
+            with torch.no_grad():
+                sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
+            customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)

         if bias:
-            bias = torch.empty(self.out_features, **factory_kwargs)
-            with torch.no_grad():
-                sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn)
-            self.bias = customized_distributed_tensor_to_param(sharded_bias)
+            if bias_ is None:
+                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+            else:
+                bias_.data = bias_.data.to(device=device, dtype=dtype)
+                self.bias = bias_
+            if not is_customized_distributed_tensor(self.bias):
+                with torch.no_grad():
+                    sharded_bias = distribute_tensor_with_customization(self.bias.data, shard_fn, gather_fn)
+                customized_distributed_tensor_to_existing_param(sharded_bias, self.bias)
         else:
             self.bias = None

+        # offset the seed with randomizer index and rank
+        if weight is None:
+            seed = torch.random.initial_seed()
+            self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
         # init weights
         self.reset_parameters(weight_initializer, bias_initializer)
@@ -613,24 +630,26 @@ class FusedLinear1D_Col(ParallelModule):
                                      bias=bias,
                                      device=device,
                                      process_group=process_group,
+                                     weight=module.weight,
+                                     bias_=module.bias,
                                      *args,
                                      **kwargs)

-        # TODO: copy the sharded weights
-        with torch.no_grad():
-            sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
-                                                           n_fused=n_fused,
-                                                           process_group=process_group,
-                                                           is_transposed=False)
-            linear_1d.weight.data.copy_(sharded_weight.data)
-            if bias:
-                sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
-                                                             n_fused=n_fused,
-                                                             process_group=process_group,
-                                                             is_transposed=False)
-                linear_1d.bias.data.copy_(sharded_bias.data)
+        # # TODO: copy the sharded weights
+        # with torch.no_grad():
+        #     sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
+        #                                                    n_fused=n_fused,
+        #                                                    process_group=process_group,
+        #                                                    is_transposed=False)
+        #     linear_1d.weight.data.copy_(sharded_weight.data)
+        #     if bias:
+        #         sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
+        #                                                      n_fused=n_fused,
+        #                                                      process_group=process_group,
+        #                                                      is_transposed=False)
+        #         linear_1d.bias.data.copy_(sharded_bias.data)
+        print(linear_1d.weight.shape)

         return linear_1d

     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
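The shard_fn / gather_fn pair is the heart of this change: from_native_module no longer copies pre-split tensors itself (that block is now commented out), and gather_fn respects self.n_fused instead of a hardcoded 3; the constructor receives the native weight and bias and shards them in place. A minimal sketch of what GPT-2-style fused-QKV splitting means, in plain PyTorch with a hypothetical helper name (the real split_fused_qkv_in_gpt2_style additionally takes a process group and an is_transposed flag):

import torch

def split_fused_qkv_gpt2_style_sketch(tensor: torch.Tensor, n_fused: int,
                                      world_size: int, rank: int) -> torch.Tensor:
    # the last dim holds q | k | v side by side, so each of the n_fused
    # sections must be sliced separately for this rank, then re-fused so
    # the shard keeps a contiguous q | k | v layout
    sections = tensor.chunk(n_fused, dim=-1)
    shards = [s.chunk(world_size, dim=-1)[rank] for s in sections]
    return torch.cat(shards, dim=-1)

weight = torch.arange(24.).reshape(2, 12)   # toy Conv1D-style (in, 3 * out) weight
shard0 = split_fused_qkv_gpt2_style_sketch(weight, n_fused=3, world_size=2, rank=0)
assert shard0.shape == (2, 6)               # rank 0 holds half of q, half of k, half of v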
colossalai/shardformer/policies/blip2.py

@@ -4,7 +4,7 @@ import colossalai.shardformer.layer as col_nn
 from .._utils import getattr_, setattr_
 from ..modeling.blip2 import forward_fn
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ['BlipPolicy', 'BlipModelPolicy']
colossalai/shardformer/policies/chatglm.py

@@ -4,7 +4,7 @@ import torch.nn as nn
 import colossalai.shardformer.layer as col_nn

-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
colossalai/shardformer/policies/sam.py

@@ -4,7 +4,7 @@ import colossalai.shardformer.layer as col_nn
 from .._utils import getattr_, setattr_
 from ..modeling.sam import forward_fn
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ['SamPolicy', 'SamModelPolicy']
colossalai/shardformer/policies/whisper.py

@@ -3,7 +3,7 @@ import torch.nn as nn
 import colossalai.shardformer.layer as col_nn

 from .._utils import getattr_, setattr_
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = [
     'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification'
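All four policy files pick up the same one-line rename: the basepolicy module is now base_policy. Any downstream code importing from the old path must follow suit; assuming a caller that imports Policy directly, the migration looks like:

# before this commit
# from colossalai.shardformer.policies.basepolicy import Policy
# after this commit
from colossalai.shardformer.policies.base_policy import Policy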
tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py

+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close

 import colossalai
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
 from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

 # This code is copied from https://github.com/huggingface/transformers
@@ -50,9 +53,13 @@ def rearrange(tensor: torch.Tensor, dim: int):
     return rearanged_tensor


-def check_gpt2_linear_conv_1d_col():
+@parameterize('lazy_init', [False, True])
+def check_linear_conv_1d_col(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = Conv1D(192, 48).cuda()
-    linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear,
+    with ctx:
+        linear_copy = Conv1D(192, 48).cuda()
+    linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
                                                                    process_group=None,
                                                                    gather_output=True,
                                                                    n_fused=3)
@@ -61,6 +68,8 @@ def check_gpt2_linear_conv_1d_col():
     assert linear.bias.shape == torch.Size([192])
     assert linear_conv_col.weight.shape == torch.Size([48, 96])
     assert linear_conv_col.bias.shape == torch.Size([96])
+    assert linear_copy.weight is linear_conv_col.weight
+    assert linear_copy.bias is linear_conv_col.bias

     # ensure weights are reversibly loadable
     linear_conv_col.load_state_dict(linear.state_dict())
@@ -80,13 +89,24 @@ def check_gpt2_linear_conv_1d_col():
     assert_close(target_grad, linear_conv_col.weight.grad)


-def check_gpt2_linear_conv_1d_row():
+@parameterize('lazy_init', [False, True])
+def check_linear_conv_1d_row(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = Conv1D(192, 48).cuda()
-    linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
+    with ctx:
+        linear_copy = Conv1D(192, 48).cuda()
+    linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)

     assert linear.weight.shape == torch.Size([48, 192])
     assert linear_row.weight.shape == torch.Size([24, 192])
     assert linear_row.bias.shape == torch.Size([192])
+    assert linear_copy.weight is linear_row.weight
+    assert linear_copy.bias is linear_row.bias
+
+    # ensure weights are reversibly loadable
+    linear_row.load_state_dict(linear.state_dict())
+    linear.load_state_dict(linear_row.state_dict())

     # check computation correctness
     x = torch.rand(4, 48).cuda()
@@ -107,14 +127,14 @@ def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     # test for linear conv
-    check_gpt2_linear_conv_1d_col()
-    check_gpt2_linear_conv_1d_row()
+    check_linear_conv_1d_col()
+    check_linear_conv_1d_row()


 @rerun_if_address_is_in_use()
-def test_gpt2_linearconv():
+def test_linearconv():
     spawn(run_dist, nprocs=2)


 if __name__ == '__main__':
-    test_gpt2_linearconv()
+    test_linearconv()
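The tests now exercise both eager and lazy construction: under LazyInitContext the Conv1D's tensors are not materialized until from_native_module shards them, and the new `is` assertions check that sharding happens on the existing parameter objects rather than on fresh copies. A minimal sketch of the toggle pattern, assuming only that LazyInitContext can wrap module construction as the diff shows:

from contextlib import nullcontext

import torch.nn as nn

from colossalai.lazy import LazyInitContext

def build_linear(lazy_init: bool) -> nn.Linear:
    # pick a real lazy-init context or a no-op one, exactly as the updated tests do
    ctx = LazyInitContext() if lazy_init else nullcontext()
    with ctx:
        # with lazy_init=True this allocates no storage yet; a later
        # from_native_module call materializes and shards it in place
        module = nn.Linear(48, 192)
    return module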
tests/test_shardformer/test_model/test_shard_chatglm.py

@@ -84,9 +84,10 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     if name == "transformers_chatglm":
-        sharded_model = shard_former.optimize(model_copy, ChatGLMModelPolicy()).cuda()
+        sharded_model, _ = shard_former.optimize(model_copy, ChatGLMModelPolicy())
     else:
-        sharded_model = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy()).cuda()
+        sharded_model, _ = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy())
+    sharded_model = sharded_model.cuda()

     check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
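This test adapts to the new ShardFormer.optimize signature, which now returns a pair rather than the bare module; the second element (discarded as `_` above) appears to carry shared-parameter information, and moving the model to the GPU is now the caller's job. Usage, as reconstructed from the diff:

# optimize() now returns (sharded_model, extras); .cuda() is no longer chained
sharded_model, _ = shard_former.optimize(model_copy, ChatGLMModelPolicy())
sharded_model = sharded_model.cuda()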