OpenDAS / ColossalAI / Commits / 44a190e6

Commit 44a190e6, authored Jun 30, 2023 by Frank Lee

[shardformer] import huggingface implicitly (#4101)

parent 6a88bae4

Showing 9 changed files with 91 additions and 38 deletions (+91, -38):
colossalai/shardformer/policies/autopolicy.py     +2   -0
colossalai/shardformer/policies/basepolicy.py     +2   -0
colossalai/shardformer/policies/bert.py           +21  -9
colossalai/shardformer/policies/gpt2.py           +12  -2
colossalai/shardformer/policies/llama.py          +9   -3
colossalai/shardformer/policies/opt.py            +9   -8
colossalai/shardformer/policies/t5.py             +14  -13
colossalai/shardformer/policies/vit.py            +5   -2
colossalai/shardformer/shard/shard_config.py      +17  -1
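The change is the same in every policy file below: top-level `transformers` imports are deleted and re-declared inside the `module_policy()` methods, so `transformers` is only imported when a policy is actually built rather than when the policy package is imported. A minimal standalone sketch of the pattern (the class name and placeholder value are illustrative, not ColossalAI code):

# Sketch of the "implicit import" pattern this commit applies.
class ExamplePolicy:

    def module_policy(self):
        # Deferred import: transformers is only required once module_policy()
        # is called, so importing the policy module itself no longer needs it.
        from transformers.models.bert.modeling_bert import BertLayer
        return {BertLayer: "module replacement description goes here"}

With the old top-level imports, importing a policy module failed immediately on an installation without `transformers`; with the deferred form it only fails if that policy is actually requested.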
colossalai/shardformer/policies/autopolicy.py

@@ -5,6 +5,8 @@ import torch.nn as nn
 from .basepolicy import Policy
 
+__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
+
 @dataclass
 class PolicyLocation:
 ...
colossalai/shardformer/policies/basepolicy.py

@@ -8,6 +8,8 @@ import torch.nn as nn
 from ..shard.shard_config import ShardConfig
 
+__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"]
+
 class ParallelModule():
 ...
colossalai/shardformer/policies/bert.py

 import torch.nn as nn
-from transformers.models.bert.modeling_bert import (
-    BertEmbeddings,
-    BertForMultipleChoice,
-    BertForSequenceClassification,
-    BertForTokenClassification,
-    BertLayer,
-    BertLMPredictionHead,
-)
 
 import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
+__all__ = [
+    'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy',
+    'BertForMaskedLMPolicy', 'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy',
+    'BertForTokenClassificationPolicy', 'BertForMultipleChoicePolicy'
+]
+
 class BertPolicy(Policy):
 ...

@@ -33,6 +31,8 @@ class BertPolicy(Policy):
         return self.model
 
     def module_policy(self):
+        from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer
+
         base_policy = {
             BertLayer:
                 ModulePolicyDescription(
 ...

@@ -123,7 +123,7 @@ class BertPolicy(Policy):
     def new_model_class(self):
         # do nothing
-        return self.model
+        return None
 
     def postprocess(self):
         return self.model
 ...

@@ -143,6 +143,8 @@ class BertForPretrainingPolicy(BertPolicy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.bert.modeling_bert import BertLMPredictionHead
+
         module_policy = super().module_policy()
         addon_module = {
             BertLMPredictionHead:
 ...

@@ -184,6 +186,8 @@ class BertLMHeadModelPolicy(BertPolicy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.bert.modeling_bert import BertLMPredictionHead
+
         module_policy = super().module_policy()
         addon_module = {
             BertLMPredictionHead:
 ...

@@ -221,6 +225,8 @@ class BertForMaskedLMPolicy(BertPolicy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.bert.modeling_bert import BertLMPredictionHead
+
         module_policy = super().module_policy()
         addon_module = {
             BertLMPredictionHead:
 ...

@@ -261,6 +267,8 @@ class BertForSequenceClassificationPolicy(BertPolicy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.bert.modeling_bert import BertForSequenceClassification
+
         module_policy = super().module_policy()
         addon_module = {
             BertForSequenceClassification:
 ...

@@ -284,6 +292,8 @@ class BertForTokenClassificationPolicy(BertPolicy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.bert.modeling_bert import BertForTokenClassification
+
         module_policy = super().module_policy()
         addon_module = {
             BertForTokenClassification:
 ...

@@ -314,6 +324,8 @@ class BertForMultipleChoicePolicy(BertPolicy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.bert.modeling_bert import BertForMultipleChoice
+
         module_policy = super().module_policy()
         addon_module = {
             BertForMultipleChoice:
 ...
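The task-specific subclasses (BertForPretrainingPolicy, BertLMHeadModelPolicy, BertForMaskedLMPolicy, and so on) all follow the same shape. The lines that merge `addon_module` into the inherited policy are collapsed in the diff above, so the update/return step in this sketch is an assumption:

# Sketch of the subclass pattern after this change; BertForMaskedLMPolicySketch is
# a hypothetical name, and the update()/return lines are assumed (collapsed above).
class BertForMaskedLMPolicySketch(BertPolicy):

    def module_policy(self):
        # huggingface import is now local to the method
        from transformers.models.bert.modeling_bert import BertLMPredictionHead

        module_policy = super().module_policy()
        addon_module = {
            BertLMPredictionHead: ...    # head-specific ModulePolicyDescription
        }
        module_policy.update(addon_module)    # assumed merge step
        return module_policy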
colossalai/shardformer/policies/gpt2.py

 import torch.nn as nn
-from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model
 
 import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
+__all__ = [
+    'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy',
+    'GPT2ForTokenClassificationPolicy', 'GPT2ForSequenceClassificationPolicy'
+]
+
 class GPT2Policy(Policy):
 ...

@@ -25,7 +29,9 @@ class GPT2Policy(Policy):
         return self.model
 
     def module_policy(self):
-        base_policy = {
+        from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
+
+        return {
             GPT2Model:
                 ModulePolicyDescription(attribute_replacement={},
                                         param_replacement=[],
 ...

@@ -125,6 +131,8 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+
         module_policy = super().module_policy()
         addon_module = {
             GPT2LMHeadModel:
 ...

@@ -156,6 +164,8 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
         super().__init__()
 
     def module_policy(self):
+        from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel
+
         module_policy = super().module_policy()
         addon_module = {
             GPT2DoubleHeadsModel:
 ...
colossalai/shardformer/policies/llama.py

 from typing import Dict, Union
 
 import torch.nn as nn
-from transformers import LlamaForCausalLM, LlamaForSequenceClassification
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
 
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
 
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
+__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
+
 class LlamaPolicy(Policy):
 ...

@@ -26,7 +26,9 @@ class LlamaPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        base_policy = {
+        from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+
+        return {
             LlamaDecoderLayer:
                 ModulePolicyDescription(
                     attribute_replacement={
 ...

@@ -109,6 +111,8 @@ class LlamaPolicy(Policy):
 class LlamaForCausalLMPolicy(LlamaPolicy):
 
     def module_policy(self):
+        from transformers import LlamaForCausalLM
+
         policy = super().module_policy()
         # add a new item for casual lm
         new_item = {
 ...

@@ -128,6 +132,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
 class LlamaForSequenceClassificationPolicy(LlamaPolicy):
 
     def module_policy(self):
+        from transformers import LlamaForSequenceClassification
+
         policy = super().module_policy()
         # add a new item for sequence classification
 ...
colossalai/shardformer/policies/opt.py

-from transformers.models.opt.modeling_opt import (
-    OPTAttention,
-    OPTDecoder,
-    OPTDecoderLayer,
-    OPTForCausalLM,
-    OPTForSequenceClassification,
-)
-
 from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row
 
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
+__all__ = [
+    'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy',
+    'OPTForQuestionAnsweringPolicy'
+]
+
 class OPTPolicy(Policy):
 ...

@@ -29,6 +26,8 @@ class OPTPolicy(Policy):
         return self.model
 
     def module_policy(self):
+        from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
+
         base_policy = {
             OPTDecoder:
                 ModulePolicyDescription(attribute_replacement={},
 ...

@@ -111,6 +110,8 @@ class OPTModelPolicy(OPTPolicy):
 class OPTForCausalLMPolicy(OPTPolicy):
 
     def module_policy(self):
+        from transformers.models.opt.modeling_opt import OPTForCausalLM
+
         policy = super().module_policy()
         new_item = {
             OPTForCausalLM:
 ...
colossalai/shardformer/policies/t5.py

-from transformers import T5ForConditionalGeneration
-from transformers.models.t5.modeling_t5 import (
-    T5Attention,
-    T5DenseActDense,
-    T5DenseGatedActDense,
-    T5LayerCrossAttention,
-    T5LayerFF,
-    T5LayerSelfAttention,
-    T5Stack,
-)
-
-from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row
 
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 ...

@@ -34,7 +23,17 @@ class T5ModelPolicy(Policy):
         return self.model
 
     def module_policy(self):
-        base_policy = {
+        from transformers.models.t5.modeling_t5 import (
+            T5Attention,
+            T5DenseActDense,
+            T5DenseGatedActDense,
+            T5LayerCrossAttention,
+            T5LayerFF,
+            T5LayerSelfAttention,
+            T5Stack,
+        )
+
+        return {
             T5Stack:
                 ModulePolicyDescription(attribute_replacement={},
                                         param_replacement=[],
 ...

@@ -165,6 +164,8 @@ class T5ModelPolicy(Policy):
 class T5ForConditionalGenerationPolicy(T5ModelPolicy):
 
     def module_policy(self):
+        from transformers import T5ForConditionalGeneration
+
         policy = super().module_policy()
         new_item = {
 ...
colossalai/shardformer/policies/vit.py

 from typing import Dict, Union
 
 import torch.nn as nn
-from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel
 
 from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
 
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
+__all__ = ['ViTPolicy']
+
 class ViTPolicy(Policy):
 ...

@@ -25,7 +26,9 @@ class ViTPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        base_policy = {
+        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
+
+        return {
             ViTEmbeddings:
                 ModulePolicyDescription(attribute_replacement={},
                                         param_replacement=[],
 ...
colossalai/shardformer/shard/shard_config.py

@@ -19,6 +19,7 @@ class ShardConfig:
     """
     tensor_parallel_process_group: int = None
     enable_fused_normalization: bool = False
+    enable_all_optimization: bool = False
 
     # TODO: add support for tensor parallel
     # pipeline_parallel_size: int
 ...

@@ -27,6 +28,21 @@ class ShardConfig:
     # inference_only: bool = True
     # gather_output: bool = True
 
+    @property
+    def tensor_parallel_size(self):
+        return self._tensor_parallel_size
+
     def __post_init__(self):
         # get the parallel size
-        self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
+        self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
+
+        # turn on all optimization if all_optimization is set to True
+        if self.enable_all_optimization:
+            self._turn_on_all_optimization()
+
+    def _turn_on_all_optimization(self):
+        """
+        Turn on all optimization.
+        """
+        # you can add all the optimization flag here
+        self.fused_layernorm = True
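For reference, a usage sketch of the updated ShardConfig (field and property names are taken from the diff; the call assumes torch.distributed has already been initialized, since __post_init__ calls dist.get_world_size()):

# Usage sketch for the updated ShardConfig; requires an initialized default process group.
from colossalai.shardformer.shard.shard_config import ShardConfig

config = ShardConfig(
    tensor_parallel_process_group=None,    # None falls back to the default process group
    enable_fused_normalization=True,
    enable_all_optimization=True,          # new flag: triggers _turn_on_all_optimization()
)
print(config.tensor_parallel_size)         # now a read-only property backed by _tensor_parallel_size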