OpenDAS / ColossalAI · Commit 537a52b7

Authored May 24, 2023 by Frank Lee; committed by FrankLeeeee on Jun 08, 2023
Parent: bc19024b

[shardformer] refactored the user api (#3828)

* [shardformer] refactored the user api
* polish code

Showing 7 changed files with 35 additions and 87 deletions.

colossalai/shardformer/README.md                +3   -3
colossalai/shardformer/shard/__init__.py        +5   -0
colossalai/shardformer/shard/shard_config.py    +2   -0
colossalai/shardformer/shard/sharder.py         +18  -9
colossalai/shardformer/shard/shardmodel.py      +0   -60
colossalai/shardformer/shard/slicer.py          +1   -6
colossalai/shardformer/test/test.py             +6   -9

colossalai/shardformer/README.md

@@ -18,7 +18,7 @@
 The sample API usage is given below:

 ```python
-from colossalai.shardformer.shard.shardmodel import ShardModel
+from colossalai.shardformer import shard_model
 from transformers import BertForMaskedLM

 # create huggingface model as normal
@@ -26,11 +26,11 @@ model = BertForMaskedLM.from_pretrained("bert-base-uncased")

 # make the huggingface model paralleled to ShardModel
 # auto policy:
-shardmodel = ShardModel(model).model
+sharded_model = shard_model(model)

 # custom policy:
 from xxx import <POLICYCLASS>
-shardmodel = ShardModel(model, <POLICYCLASS>).model
+sharded_model = shard_model(model, <POLICYCLASS>)

 # do angthing as normal
...
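
Taken together, the README hunk above captures the whole user-facing change: the `ShardModel(model).model` wrapper becomes a plain `shard_model(model)` call. Below is a minimal sketch of the refactored usage; the import path follows `shard/__init__.py` and the updated test script (the README hunk itself imports `shard_model` from the package root), and the custom-policy class is a hypothetical stand-in for the README's `<POLICYCLASS>` placeholder. As with the README snippet, a distributed environment is assumed to be initialized before sharding.

```python
from transformers import BertForMaskedLM

# Exported by colossalai/shardformer/shard/__init__.py in this commit.
from colossalai.shardformer.shard import shard_model

# Create the huggingface model as normal.
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# Auto policy: shardformer resolves a policy from the model class.
sharded_model = shard_model(model)

# Custom policy: pass a Policy subclass as the second argument.
# `MyBertPolicy` is a hypothetical stand-in for the README's <POLICYCLASS>.
# from my_policies import MyBertPolicy
# sharded_model = shard_model(model, MyBertPolicy)
```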

colossalai/shardformer/shard/__init__.py

+from .shard_config import ShardConfig
+from .sharder import ModelSharder, shard_model
+from .slicer import Slicer
+
+__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer']

colossalai/shardformer/shard/shardconfig.py → colossalai/shardformer/shard/shard_config.py

 from dataclasses import dataclass

+__all__ = ['ShardConfig']
+
 @dataclass
 class ShardConfig:
...
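
Both of the new `__all__` lists (here and in `shard/__init__.py` above) only pin down what a star-import exposes from these modules. A generic illustration of the effect, using a toy module rather than the ColossalAI ones:

```python
# toy_config.py -- illustrative only, not part of ColossalAI
from dataclasses import dataclass

__all__ = ['ToyConfig']    # names listed here are what `from toy_config import *` binds

_INTERNAL_DEFAULT = 1      # not listed, so star-imports leave it out


@dataclass
class ToyConfig:
    world_size: int = _INTERNAL_DEFAULT
```

Explicit imports (`from toy_config import _INTERNAL_DEFAULT`) are unaffected; `__all__` only filters the wildcard form.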

colossalai/shardformer/shard/sharder.py

-import os
-from dataclasses import dataclass
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List

 import torch
 import torch.nn as nn

 import colossalai.nn as col_nn
 from colossalai.logging import get_dist_logger

 from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Layer, Policy
+from ..policies.basepolicy import Policy
 from ..utils.utils import getattr_, hasattr_, setattr_
-from .shardconfig import ShardConfig
+from .shard_config import ShardConfig
 from .slicer import Slicer

 logger = get_dist_logger()

+__all__ = ['ModelSharder', 'shard_model']

 class ModelSharder(object):
...
@@ -245,3 +240,17 @@ class ModelSharder(object):
                 param = nn.Parameter(param)
             setattr_(model, k, param)
             setattr_(model, v, param)
+
+
+def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None):
+    r"""
+    The function is used to shard the PyTorch model.
+
+    Args:
+        model (`torch.nn.Model`): the origin huggingface model
+        shard_config (`ShardConfig`): the config for distribute information
+        policy (`Policy`): the custom policy for sharding
+    """
+    sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
+    sharder.shard()
+    return model
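
The `shard_model` function added at the end of this file is the new single entry point: it builds a `ModelSharder`, shards the module in place, and returns the same object. A short sketch of what that implies for callers; `BertForMaskedLM` is just the model this commit's README and test use, and as there a distributed environment is assumed to be set up already.

```python
from transformers import BertForMaskedLM

from colossalai.shardformer.shard import shard_model

model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# shard_config and policy both default to None; with no policy given, the
# sharder falls back to the auto policy resolved from the model class
# (see the get_autopolicy import above).
sharded = shard_model(model)

# shard_model mutates the module via ModelSharder.shard() and then returns
# the very same object, so both names refer to one sharded module.
assert sharded is model
```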

colossalai/shardformer/shard/shardmodel.py
deleted 100644 → 0

-import os
-from contextlib import suppress
-from dataclasses import dataclass
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-import transformers
-
-from colossalai.tensor.d_tensor.layout import Layout
-
-from ..policies.basepolicy import Policy
-from .shardconfig import ShardConfig
-from .sharder import ModelSharder
-
-
-class ShardModel(object):
-    r"""
-    The class for sharding the huggingface model, ''self.model'' is the sharded model
-    Just creat a new ShardModel object to shard huggingface model
-
-    Args:
-        model (:class:`torch.nn.Model`): the origin huggingface model
-        dist_config (:class:`ShardConfig`): the config for distribute information
-        custom_policy (:class:`Policy`): the custom policy for sharding
-    """
-
-    def __init__(
-        self,
-        model: nn.Module,
-        shard_config: ShardConfig = None,    # TODO
-        custom_policy: Policy = None,
-    ) -> None:
-        self.model = model
-        self.shard_config = shard_config
-        self.policy = custom_policy
-        # self.layout=, # TODO
-
-        sharder = ModelSharder(
-            model=self.model,
-            policy=self.policy,
-            shard_config=self.shard_config,
-        )
-        sharder.shard()
-
-    def set_environ(self) -> None:
-        os.environ["TOKENIZERS_PARALLELISM"] = "true"
-        os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"
-        os.environ["MASTER_ADDR"] = str(self.dist_config.master_addr)
-        os.environ["MASTER_PORT"] = str(self.dist_config.master_port)
-        os.environ["WORLD_SIZE"] = str(self.dist_config.num_gpus)
-        os.environ["RANK"] = str(self.dist_config.rank)
-        os.environ["LOCAL_RANK"] = str(self.dist_config.rank)
-
-        if not dist.is_initialized():
-            dist.init_process_group(backend=self.dist_config.backend)
-
-        torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))
-
-    def back_to_org() -> None:
-        pass

colossalai/shardformer/shard/slicer.py

-import os
-from dataclasses import dataclass
-from typing import Dict, Tuple
-import torch
-import torch.distributed as dist
 from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
-from .shardconfig import ShardConfig
+from .shard_config import ShardConfig

 dim_mapping = {Col_Layer: 1, Row_Layer: 0}
...
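
The `dim_mapping` table kept at the top of this file records which weight dimension is split for each layer kind: column-parallel layers along dim 1, row-parallel layers along dim 0. The sketch below illustrates just that idea with a plain `torch.chunk`; it is not the actual `Slicer` implementation, whose body is collapsed in this diff, and the string keys stand in for the policy `Layer` classes used in the real mapping.

```python
import torch

# Mirror of the mapping in slicer.py ({Col_Layer: 1, Row_Layer: 0}),
# keyed by strings here instead of the policy Layer classes.
dim_mapping = {"col": 1, "row": 0}


def slice_weight(weight: torch.Tensor, style: str, rank: int, world_size: int) -> torch.Tensor:
    """Return this rank's shard of `weight`, split along the mapped dim.

    Illustrative only: the real Slicer also handles biases, gather flags, etc.
    """
    dim = dim_mapping[style]
    return torch.chunk(weight, world_size, dim=dim)[rank]


w = torch.randn(8, 6)
print(slice_weight(w, "col", rank=0, world_size=2).shape)  # torch.Size([8, 3])
print(slice_weight(w, "row", rank=0, world_size=2).shape)  # torch.Size([4, 6])
```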

colossalai/shardformer/test/test.py

 import argparse
-import inspect
 import os

 import torch
...
@@ -7,12 +5,10 @@ import torch.nn as nn
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
+from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling

 import colossalai
 from colossalai.logging import get_dist_logger
-from colossalai.shardformer.shard.shardconfig import ShardConfig
-from colossalai.shardformer.shard.shardmodel import ShardModel
+from colossalai.shardformer.shard import ShardConfig, shard_model
 from colossalai.utils import get_current_device, print_rank_0

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
...
@@ -93,8 +89,9 @@ if __name__ == "__main__":
         rank=int(str(get_current_device()).split(':')[-1]),
         world_size=int(os.environ['WORLD_SIZE']),
     )
-    shardmodel = ShardModel(model, shard_config)
+
+    sharded_model = shard_model(model, shard_config)
     if args.mode == "train":
-        train(shardmodel.model)
+        train(sharded_model)
     elif args.mode == "inference":
-        inference(shardmodel.model)
+        inference(sharded_model)
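
One small idiom in the kept context above is how the test recovers the rank: it string-splits the device returned by `get_current_device()`. A toy illustration of what that expression computes, using a plain `torch.device` as a stand-in for the colossalai helper:

```python
import torch

# Stand-in for colossalai.utils.get_current_device(), which returns the
# torch.device bound to the current process (e.g. device("cuda:3")).
device = torch.device("cuda:3")

# Same parsing as in test.py: take the index after the colon.
rank = int(str(device).split(":")[-1])
print(rank)  # 3
```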