OpenDAS / ColossalAI / Commits

Commit 7740c55c, authored Jun 22, 2023 by FoolPlayer, committed by Frank Lee on Jul 04, 2023

support kit use for bert/gpt test (#4055)

* support kit use for bert test
* support kit test for gpt2

parent f22ddace

Showing 7 changed files with 319 additions and 246 deletions (+319 / -246):
colossalai/shardformer/policies/autopolicy.py           +15   -5
colossalai/shardformer/policies/bert.py                 +15   -8
colossalai/shardformer/policies/gpt2.py                 +79   -2
tests/kit/model_zoo/transformers/bert.py               +102  -38
tests/kit/model_zoo/transformers/gpt.py                 +54  -15
tests/test_shardformer/test_model/test_shard_bert.py    +24  -92
tests/test_shardformer/test_model/test_shard_gpt2.py    +30  -86
colossalai/shardformer/policies/autopolicy.py

@@ -25,17 +25,19 @@ class PolicyLocation:
 _POLICY_LIST = {
     # BERT
     "transformers.models.bert.modeling_bert.BertModel":
-        PolicyLocation(file_name="bert", class_name="BertPolicy"),
+        PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
     "transformers.models.bert.modeling_bert.BertForPreTraining":
         PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy"),
-    "transformers.models.bert.modeling_bert.BertForMaskedLM":
-        PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),
     "transformers.models.bert.modeling_bert.BertLMHeadModel":
         PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"),
-    "transformers.models.bert.modeling_bert.BertForNextSentencePrediction":
-        PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
+    "transformers.models.bert.modeling_bert.BertForMaskedLM":
+        PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),
     "transformers.models.bert.modeling_bert.BertForSequenceClassification":
         PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"),
     "transformers.models.bert.modeling_bert.BertForTokenClassification":
         PolicyLocation(file_name="bert", class_name="BertForTokenClassificationPolicy"),
+    "transformers.models.bert.modeling_bert.BertForNextSentencePrediction":
+        PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
+    "transformers.models.bert.modeling_bert.BertForMultipleChoice":
+        PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),

@@ -58,6 +60,14 @@ _POLICY_LIST = {
     # GPT2
     "transformers.models.gpt2.modeling_gpt2.GPT2Model":
         PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"),
+    "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel":
+        PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"),
+    "transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel":
+        PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"),
+    "transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification":
+        PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
+    "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
+        PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"),
 }
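Aside: _POLICY_LIST maps the fully qualified class name of a Hugging Face model to the PolicyLocation of the shardformer policy that shards it. A minimal sketch of how such a table is typically consumed; the get_policy helper below is illustrative only and is not the actual autopolicy.py code:

import importlib

def get_policy(model):
    # Hypothetical lookup helper, shown only to explain the table above.
    # Build e.g. "transformers.models.bert.modeling_bert.BertModel" ...
    qualname = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
    location = _POLICY_LIST[qualname]
    # ... then load the policy class from colossalai.shardformer.policies.<file_name>.
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)()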
colossalai/shardformer/policies/bert.py

@@ -131,8 +131,8 @@ class BertForPretrainingPolicy(BertPolicy):
         return self.model


-# BertForMaskedLM
-class BertForMaskedLMPolicy(BertPolicy):
+# BertLMHeadModel
+class BertLMHeadModelPolicy(BertPolicy):

     def __init__(self) -> None:
         super().__init__()

@@ -162,8 +162,8 @@ class BertForMaskedLMPolicy(BertPolicy):
         return self.model


-# BertLMHeadModel
-class BertLMHeadModelPolicy(BertPolicy):
+# BertForMaskedLM
+class BertForMaskedLMPolicy(BertPolicy):

     def __init__(self) -> None:
         super().__init__()

@@ -193,15 +193,22 @@ class BertLMHeadModelPolicy(BertPolicy):
         return self.model


-# BertForNextSentencePrediction
-class BertForNextSentencePredictionPolicy(BertPolicy):
+# BertForSequenceClassification
+class BertForSequenceClassificationPolicy(BertPolicy):

     def __init__(self) -> None:
         super().__init__()


-# BertForSequenceClassification
-class BertForSequenceClassificationPolicy(BertPolicy):
+# BertForTokenClassification
+class BertForTokenClassificationPolicy(BertPolicy):

     def __init__(self) -> None:
         super().__init__()
+
+
+# BertForNextSentencePrediction
+class BertForNextSentencePredictionPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
colossalai/shardformer/policies/gpt2.py

-from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
+import torch.nn as nn
+from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model

 import colossalai.shardformer.layer as col_nn

 from .._utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -82,7 +84,6 @@ class GPT2Policy(Policy):
         }

     def new_model_class(self):
-
         return self.model

     def postprocess(self):

@@ -94,3 +95,79 @@ class GPT2ModelPolicy(GPT2Policy):
     def __init__(self) -> None:
         super().__init__()
+
+
+# GPT2LMHeadModel
+class GPT2LMHeadModelPolicy(GPT2Policy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            GPT2LMHeadModel:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="lm_head",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
+    def postprocess(self):
+        binding_map = {"transformer.wte.weight": "lm_head.weight"}
+        for k, v in binding_map.items():
+            param = getattr_(self.model, k)
+            param = nn.Parameter(param)
+            setattr_(self.model, k, param)
+            setattr_(self.model, v, param)
+        return self.model
+
+
+# GPT2DoubleHeadsModel
+class GPT2DoubleHeadsModelPolicy(GPT2Policy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            GPT2DoubleHeadsModel:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="lm_head",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
+    def postprocess(self):
+        binding_map = {"transformer.wte.weight": "lm_head.weight"}
+        for k, v in binding_map.items():
+            param = getattr_(self.model, k)
+            param = nn.Parameter(param)
+            setattr_(self.model, k, param)
+            setattr_(self.model, v, param)
+        return self.model
+
+
+# GPT2ForTokenClassification
+class GPT2ForTokenClassificationPolicy(GPT2Policy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# GPT2ForSequenceClassification
+class GPT2ForSequenceClassificationPolicy(GPT2Policy):
+
+    def __init__(self) -> None:
+        super().__init__()
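Aside: the postprocess step above re-ties the token embedding (transformer.wte.weight) and the lm_head after sharding by pointing both module paths at one nn.Parameter. A self-contained sketch of that tying pattern in plain PyTorch, using toy modules rather than ColossalAI's getattr_/setattr_ helpers:

import torch.nn as nn

# Toy model pieces, for illustration only.
embed = nn.Embedding(50258, 768)
lm_head = nn.Linear(768, 50258, bias=False)

# Tie the weights: both modules now reference a single Parameter, so a
# gradient step through either path updates the shared tensor.
shared = nn.Parameter(embed.weight.detach())
embed.weight = shared
lm_head.weight = shared
assert embed.weight.data_ptr() == lm_head.weight.data_ptr()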
tests/kit/model_zoo/transformers/bert.py

@@ -6,83 +6,147 @@ from ..registry import ModelAttribute, model_zoo
 # ===============================
 # Register single-sentence BERT
 # ===============================
-BATCH_SIZE = 2
-SEQ_LENGTH = 16
-
-
-def data_gen_fn():
-    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+
+
+# define data gen function
+def data_gen():
+    # Generated from the following code snippet
+    #
+    # from transformers import BertTokenizer
+    # input = 'Hello, my dog is cute'
+    # tokenized_input = tokenizer(input, return_tensors='pt')
+    # input_ids = tokenized_input['input_ids']
+    # attention_mask = tokenized_input['attention_mask']
+    # token_type_ids = tokenized_input['token_type_ids']
+    input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64)
+    token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
     return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+
+def data_gen_for_lm():
+    # LM data gen
+    # the `labels` for LM are the output tokens; since there is no padding, use `input_ids` as `labels`
+    data = data_gen()
+    data['labels'] = data['input_ids'].clone()
+    return data
+
+
+def data_gen_for_pretraining():
+    # pretraining data gen
+    # `next_sentence_label` is the label for next sentence prediction, 0 or 1
+    data = data_gen_for_lm()
+    data['next_sentence_label'] = torch.tensor([1], dtype=torch.int64)
+    return data
+
+
+def data_gen_for_sequence_classification():
+    # sequence classification data gen
+    # `labels` is the label for sequence classification, 0 or 1
+    data = data_gen()
+    data['labels'] = torch.tensor([1], dtype=torch.int64)
+    return data
+
+
+def data_gen_for_token_classification():
+    # token classification data gen
+    # `labels` is the class of each token, not a token id, 0 or 1
+    data = data_gen()
+    data['labels'] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
+    return data
+
+
+def data_gen_for_mcq():
+    # multiple choice question data gen
+    # Generated from the following code snippet
+    #
+    # tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
+    # prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+    # choice0 = "It is eaten with a fork and a knife."
+    # choice1 = "It is eaten while held in the hand."
+    # data = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
+    # data = {k: v.unsqueeze(0) for k, v in data.items()}
+    # data['labels'] = torch.tensor([0], dtype=torch.int64)
+    input_ids = torch.tensor(
+        [[[101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102],
+          [101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096, 2218, 1999, 1996, 2192, 1012, 102, 0]]])
+    token_type_ids = torch.tensor(
+        [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
+    attention_mask = torch.tensor(
+        [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
+    labels = torch.tensor([0], dtype=torch.int64)
+    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)


 # define output transform function
 output_transform_fn = lambda x: x

-config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
+# define loss function
+loss_fn_for_bert_model = lambda x: x.pooler_output.mean()
+loss_fn = lambda x: x.loss
+
+config = transformers.BertConfig(hidden_size=128,
+                                 num_hidden_layers=2,
+                                 num_attention_heads=4,
+                                 intermediate_size=256,
+                                 hidden_dropout_prob=0,
+                                 attention_probs_dropout_prob=0)

 # register the BERT variants
 model_zoo.register(name='transformers_bert',
                    model_fn=lambda: transformers.BertModel(config),
-                   data_gen_fn=data_gen_fn,
+                   data_gen_fn=data_gen,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn_for_bert_model,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_bert_for_pretraining',
                    model_fn=lambda: transformers.BertForPreTraining(config),
-                   data_gen_fn=data_gen_fn,
+                   data_gen_fn=data_gen_for_pretraining,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_bert_lm_head_model',
                    model_fn=lambda: transformers.BertLMHeadModel(config),
-                   data_gen_fn=data_gen_fn,
+                   data_gen_fn=data_gen_for_lm,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_bert_for_masked_lm',
                    model_fn=lambda: transformers.BertForMaskedLM(config),
-                   data_gen_fn=data_gen_fn,
+                   data_gen_fn=data_gen_for_lm,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_bert_for_sequence_classification',
                    model_fn=lambda: transformers.BertForSequenceClassification(config),
-                   data_gen_fn=data_gen_fn,
+                   data_gen_fn=data_gen_for_sequence_classification,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_bert_for_token_classification',
                    model_fn=lambda: transformers.BertForTokenClassification(config),
-                   data_gen_fn=data_gen_fn,
+                   data_gen_fn=data_gen_for_token_classification,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))

 # ===============================
 # Register multi-sentence BERT
 # ===============================
-def data_gen_for_next_sentence():
-    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
-    prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
-    next_sentence = "The sky is blue due to the shorter wavelength of blue light."
-    encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
-    return encoding
-
-
-def data_gen_for_mcq():
-    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
-    prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
-    choice0 = "It is eaten with a fork and a knife."
-    choice1 = "It is eaten while held in the hand."
-    encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
-    encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
-    return encoding
-
-
 # register the following models
 model_zoo.register(name='transformers_bert_for_next_sentence',
                    model_fn=lambda: transformers.BertForNextSentencePrediction(config),
-                   data_gen_fn=data_gen_for_next_sentence,
+                   data_gen_fn=data_gen_for_sequence_classification,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_bert_for_mcq',
                    model_fn=lambda: transformers.BertForMultipleChoice(config),
                    data_gen_fn=data_gen_for_mcq,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
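Aside: each model_zoo.register entry bundles a model constructor with a matching data generator, output transform, and loss extractor, so a test can drive any variant generically. A rough sketch of how a consumer iterates a sub-registry; the tuple layout is inferred from how the test files later in this commit unpack it:

import torch
from tests.kit.model_zoo import model_zoo

# Entries registered under the 'transformers_bert' prefix.
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
    model = model_fn()
    data = data_gen_fn()
    output = output_transform_fn(model(**data))
    loss = loss_fn(output)    # scalar tensor suitable for .backward()
    assert isinstance(loss, torch.Tensor)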
tests/kit/model_zoo/transformers/gpt.py

@@ -11,47 +11,86 @@ SEQ_LENGTH = 16
 def data_gen():
-    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+    # Generated from the following code snippet
+    #
+    # from transformers import GPT2Tokenizer
+    # input = 'Hello, my dog is cute'
+    # tokenized_input = tokenizer(input, return_tensors='pt')
+    # input_ids = tokenized_input['input_ids']
+    # attention_mask = tokenized_input['attention_mask']
+    input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    return dict(input_ids=input_ids, attention_mask=attention_mask)


-def seq_classification_data_gen():
-    # batch sizes should be 1 if no padding token is defined.
-    input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
-    token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
-    attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
-    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+def data_gen_for_lm():
+    # LM data gen
+    # the `labels` for LM are the output tokens; since there is no padding, use `input_ids` as `labels`
+    data = data_gen()
+    data['labels'] = data['input_ids'].clone()
+    return data
+
+
+def data_gen_for_token_classification():
+    # token classification data gen
+    # `labels` is the class of each token, not a token id, 0 or 1
+    data = data_gen()
+    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64)
+    return data
+
+
+def data_gen_for_sequence_classification():
+    # sequence classification data gen
+    data = data_gen()
+    data['labels'] = torch.tensor([0], dtype=torch.int64)
+    return data


 # define output transform function
 output_transform_fn = lambda x: x

-config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
+# define loss function
+loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
+loss_fn = lambda x: x.loss
+
+config = transformers.GPT2Config(n_layer=2,
+                                 n_head=4,
+                                 vocab_size=50258,
+                                 attn_pdrop=0,
+                                 embd_pdrop=0,
+                                 resid_pdrop=0,
+                                 summary_first_dropout=0,
+                                 hidden_dropout=0,
+                                 problem_type="single_label_classification")

 # register the following models
 model_zoo.register(name='transformers_gpt',
                    model_fn=lambda: transformers.GPT2Model(config),
                    data_gen_fn=data_gen,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn_for_gpt2_model,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_gpt_lm',
                    model_fn=lambda: transformers.GPT2LMHeadModel(config),
-                   data_gen_fn=data_gen,
+                   data_gen_fn=data_gen_for_lm,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_gpt_double_heads',
                    model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
-                   data_gen_fn=data_gen,
+                   data_gen_fn=data_gen_for_lm,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_gpt_for_token_classification',
                    model_fn=lambda: transformers.GPT2ForTokenClassification(config),
-                   data_gen_fn=data_gen,
+                   data_gen_fn=data_gen_for_token_classification,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_gpt_for_sequence_classification',
                    model_fn=lambda: transformers.GPT2ForSequenceClassification(config),
-                   data_gen_fn=seq_classification_data_gen,
+                   data_gen_fn=data_gen_for_sequence_classification,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
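Aside: the LM data generators pass labels identical to input_ids because Hugging Face causal-LM heads shift the labels internally, giving a next-token prediction loss. A small self-contained check of that behavior, assuming only the torch and transformers packages (config values here are arbitrary):

import torch
import transformers

config = transformers.GPT2Config(n_layer=2, n_head=4)
model = transformers.GPT2LMHeadModel(config)

batch = dict(input_ids=torch.tensor([[15496, 11, 616, 3290, 318, 13779]]),
             attention_mask=torch.tensor([[1, 1, 1, 1, 1, 1]]))
# labels == input_ids: the model shifts them right internally, so the loss is
# the cross-entropy of predicting token t+1 from tokens <= t.
out = model(**batch, labels=batch['input_ids'].clone())
print(out.loss)    # scalar tensor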
tests/test_shardformer/test_model/test_shard_bert.py

-import copy
-import os
-
 import pytest
 import torch
-from transformers import (
-    AutoTokenizer,
-    BertConfig,
-    BertForMaskedLM,
-    BertForNextSentencePrediction,
-    BertForPreTraining,
-    BertForSequenceClassification,
-    BertLMHeadModel,
-    BertModel,
-)

 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
-tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-
-
-def build_model(world_size, model_fn):
-    config = BertConfig()
-    config.hidden_dropout_prob = 0
-    config.attention_probs_dropout_prob = 0
-
-    org_model = model_fn(config=config)
-    org_model_forshard = copy.deepcopy(org_model)
-
-    org_model.to('cuda')
-    # TODO: no need to transfer to cuda
-    org_model_forshard.to('cuda')
-
-    shard_config = ShardConfig(tensor_parallel_size=world_size,)
-    shard_former = ShardFormer(shard_config=shard_config)
-    shard_former.init_distributed()
-    sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
-
-    return org_model, sharded_model
-
-
-def check_forward(org_model, sharded_model):
-    input = 'Hello, my dog is cute'
-    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
-
-    # origin model
-    org_model.eval()
-    org_out = org_model(**tokenized_input)
-
-    # shard model
-    sharded_model.eval()
-    shard_out = sharded_model(**tokenized_input)
-
-    assert torch.allclose(
-        org_out[0], shard_out[0],
-        atol=1e-5), f"shard model output is not equal to origin model output\n{org_out[0]}\n{shard_out[0]}"
-
-
-def check_backward(org_model, sharded_model):
-    # prepare input
-    input = 'Hello, my dog is cute'
-    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
-    labels = tokenized_input['input_ids'].clone()
-    labels[labels == tokenizer.pad_token_id] = -100
-    tokenized_input['labels'] = labels
-
-    # origin model
-    org_model.train()
-    org_out = org_model(**tokenized_input)
-    org_loss = org_out.loss
-
-    # do backward
-    org_loss.backward()
-    org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad
-
-    # shard model
-    sharded_model.train()
-    shard_out = sharded_model(**tokenized_input)
-    shard_loss = shard_out.loss
-    shard_loss.backward()
-    shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad
-
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
-
-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
-    assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{shard_grad}"
+from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_model, run_forward
+
+
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    # check forward
+    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
+                                                                 output_transform_fn, loss_fn)
+    assert_hf_output_close(org_output, shard_output)
+
+    # do backward
+    org_loss.backward()
+    shard_loss.backward()
+
+    # check grad equality
+    if org_model.__class__.__name__ == 'BertModel':
+        org_grad = org_model.encoder.layer[0].attention.self.query.weight.grad
+        shard_grad = sharded_model.encoder.layer[0].attention.self.query.weight.grad
+    else:
+        org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad
+        shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad
+
+    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+
+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"


 def check_bert(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    forward_list = [
-        BertForMaskedLM,
-        BertForPreTraining,
-        BertLMHeadModel,
-    # TODO: do not work yet
-    # BertModel,
-    # BertForSequenceClassification
-    # BertForNextSentencePrediction,
-    ]
-    backward_list = [BertForMaskedLM, BertLMHeadModel]
-
-    for model_fn in forward_list:
-        org_model, sharded_model = build_model(world_size, model_fn)
-        check_forward(org_model, sharded_model)
-
-        if model_fn in backward_list:
-            check_backward(org_model, sharded_model)
-
-        torch.cuda.empty_cache()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(world_size, model_fn)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+
+    torch.cuda.empty_cache()


 @pytest.mark.dist
 @rerun_if_address_is_in_use()
+@clear_cache_before_run()
 def test_bert():
     spawn(check_bert, 2)
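Aside: run_forward is imported from tests/test_shardformer/test_model/_utils.py, which is not part of this diff. Based on how it is called here, it presumably runs both models on the generated batch and returns the transformed outputs plus their losses. A hypothetical sketch, not the actual helper:

# Hypothetical sketch of the run_forward helper used above (the real
# implementation lives in tests/test_shardformer/test_model/_utils.py and
# is not shown in this commit).
def run_forward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    data = {k: v.to('cuda') for k, v in data_gen_fn().items()}

    org_model.train()
    org_output = output_transform_fn(org_model(**data))
    org_loss = loss_fn(org_output)

    sharded_model.train()
    shard_output = output_transform_fn(sharded_model(**data))
    shard_loss = loss_fn(shard_output)

    return org_output, org_loss, shard_output, shard_loss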
tests/test_shardformer/test_model/test_shard_gpt2.py

-import copy
-import os
-
 import pytest
 import torch
-from transformers import AutoTokenizer, GPT2Config, GPT2Model

 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
-tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-
-
-def build_model(world_size, model_fn):
-    config = GPT2Config()
-    config.attn_pdrop = 0
-    config.embd_pdrop = 0
-    config.resid_pdrop = 0
-    config.summary_first_dropout = 0
-
-    org_model = model_fn(config=config)
-    org_model_forshard = copy.deepcopy(org_model)
-
-    org_model.to('cuda')
-    # TODO: no need to transfer to cuda
-    org_model_forshard.to('cuda')
-
-    shard_config = ShardConfig(tensor_parallel_size=world_size,)
-    shard_former = ShardFormer(shard_config=shard_config)
-    shard_former.init_distributed()
-    sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
-
-    return org_model, sharded_model
-
-
-def check_forward(org_model, sharded_model):
-    input = 'Hello, my dog is cute'
-    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
-
-    # origin model
-    org_model.eval()
-    org_out = org_model(**tokenized_input)
-
-    # shard model
-    sharded_model.eval()
-    shard_out = sharded_model(**tokenized_input)
-
-    assert torch.allclose(
-        org_out[0], shard_out[0],
-        atol=1e-5), f"shard model output is not equal to origin model output\n{org_out[0]}\n{shard_out[0]}"
-
-
-def check_backward(org_model, sharded_model):
-    # prepare input
-    input = 'Hello, my dog is cute'
-    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
-    labels = tokenized_input['input_ids'].clone()
-    labels[labels == tokenizer.pad_token_id] = -100
-    # tokenized_input['labels'] = labels
-
-    # origin model
-    org_model.train()
-    org_out = org_model(**tokenized_input)
-    org_loss = org_out.loss
-
-    # do backward
-    org_loss.backward()
-    org_grad = org_model.h[0].attn.c_attn.weight.grad
-
-    # shard model
-    sharded_model.train()
-    shard_out = sharded_model(**tokenized_input)
-    shard_loss = shard_out.loss
-    shard_loss.backward()
-    shard_grad = sharded_model.h[0].attn.c_attn.weight.grad
-
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
-
-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
-    assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{shard_grad}"
+from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_model, run_forward
+
+
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    # check forward
+    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
+                                                                 output_transform_fn, loss_fn)
+    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
+
+    # do backward
+    org_loss.backward()
+    shard_loss.backward()
+
+    # check grad equality
+    if org_model.__class__.__name__ == 'GPT2Model':
+        org_grad = org_model.h[0].attn.c_attn.weight.grad
+        shard_grad = sharded_model.h[0].attn.c_attn.weight.grad.transpose(0, 1).contiguous()
+    else:
+        org_grad = org_model.transformer.h[0].mlp.c_fc.weight.grad
+        shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad.transpose(0, 1).contiguous()
+
+    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+    all_shard_grad = torch.cat(shard_grad_list, dim=1)
+
+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"


-def check_bert(rank, world_size, port):
+def check_gpt2(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    forward_list = [
-        GPT2Model,
-    # TODO: do not work yet
-    # BertModel,
-    # BertForSequenceClassification
-    # BertForNextSentencePrediction,
-    ]
-    backward_list = []
-
-    for model_fn in forward_list:
-        org_model, sharded_model = build_model(world_size, model_fn)
-        check_forward(org_model, sharded_model)
-
-        if model_fn in backward_list:
-            check_backward(org_model, sharded_model)
-
-        torch.cuda.empty_cache()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        print(name)
+        # if name == 'transformers_gpt':
+        #     continue
+        org_model, sharded_model = build_model(world_size, model_fn)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+
+    torch.cuda.empty_cache()


 @pytest.mark.dist
 @rerun_if_address_is_in_use()
+@clear_cache_before_run()
 def test_gpt2():
-    spawn(check_bert, 2)
+    spawn(check_gpt2, 2)


 if __name__ == "__main__":
     test_gpt2()
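Aside: in check_forward_backward above, each tensor-parallel rank holds only one shard of the weight gradient, so the test all-gathers the shards and concatenates them along the partitioned dimension before comparing against the unsharded model. A minimal sketch of that pattern, assuming an already-initialized torch.distributed process group; the helper name is illustrative:

import torch
import torch.distributed as dist

def gather_sharded_grad(shard_grad: torch.Tensor, world_size: int, dim: int) -> torch.Tensor:
    # One buffer per rank; all_gather fills them with every rank's shard.
    shard_list = [torch.zeros_like(shard_grad) for _ in range(world_size)]
    dist.all_gather(shard_list, shard_grad)
    # Concatenate along the dimension the layer was partitioned on, e.g. dim=0
    # for a column-parallel nn.Linear weight, or dim=1 after transposing a
    # GPT-2 Conv1D-style weight as the test above does.
    return torch.cat(shard_list, dim=dim)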