Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
da3cef27
Commit
da3cef27
authored
Jul 27, 2023
by
Baizhou Zhang
Committed by
Hongxin Liu
Aug 15, 2023
Browse files
[pipeline] fix return_dict/fix pure_pipeline_test (#4331)
parent
411cf1d2
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
29 additions
and
53 deletions
+29
-53
colossalai/shardformer/modeling/bert.py
colossalai/shardformer/modeling/bert.py
+4
-29
colossalai/shardformer/modeling/bloom.py
colossalai/shardformer/modeling/bloom.py
+0
-12
colossalai/shardformer/modeling/gpt2.py
colossalai/shardformer/modeling/gpt2.py
+2
-0
colossalai/shardformer/policies/opt.py
colossalai/shardformer/policies/opt.py
+20
-8
tests/test_shardformer/test_model/test_pure_pipeline.py
tests/test_shardformer/test_model/test_pure_pipeline.py
+3
-4
No files found.
colossalai/shardformer/modeling/bert.py
View file @
da3cef27
import
warnings
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -277,9 +278,6 @@ class BertPipelineForwards:
...
@@ -277,9 +278,6 @@ class BertPipelineForwards:
if
output_hidden_states
:
if
output_hidden_states
:
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
output_hidden_states
=
False
output_hidden_states
=
False
if
return_dict
:
logger
.
warning_once
(
'return_dict is not supported for pipeline models at the moment'
)
return_dict
=
False
outputs
=
BertPipelineForwards
.
bert_model_forward
(
outputs
=
BertPipelineForwards
.
bert_model_forward
(
self
.
bert
,
self
.
bert
,
...
@@ -387,9 +385,6 @@ class BertPipelineForwards:
...
@@ -387,9 +385,6 @@ class BertPipelineForwards:
if
output_hidden_states
:
if
output_hidden_states
:
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
output_hidden_states
=
False
output_hidden_states
=
False
if
return_dict
:
logger
.
warning_once
(
'return_dict is not supported for pipeline models at the moment'
)
return_dict
=
False
outputs
=
BertPipelineForwards
.
bert_model_forward
(
outputs
=
BertPipelineForwards
.
bert_model_forward
(
self
.
bert
,
self
.
bert
,
...
@@ -478,9 +473,6 @@ class BertPipelineForwards:
...
@@ -478,9 +473,6 @@ class BertPipelineForwards:
if
output_hidden_states
:
if
output_hidden_states
:
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
output_hidden_states
=
False
output_hidden_states
=
False
if
return_dict
:
logger
.
warning_once
(
'return_dict is not supported for pipeline models at the moment'
)
return_dict
=
False
outputs
=
BertPipelineForwards
.
bert_model_forward
(
outputs
=
BertPipelineForwards
.
bert_model_forward
(
self
.
bert
,
self
.
bert
,
...
@@ -579,16 +571,15 @@ class BertPipelineForwards:
...
@@ -579,16 +571,15 @@ class BertPipelineForwards:
FutureWarning
,
FutureWarning
,
)
)
labels
=
kwargs
.
pop
(
"next_sentence_label"
)
labels
=
kwargs
.
pop
(
"next_sentence_label"
)
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
if
output_attentions
:
if
output_attentions
:
logger
.
warning_once
(
'output_attentions=True is not supported for pipeline models at the moment.'
)
logger
.
warning_once
(
'output_attentions=True is not supported for pipeline models at the moment.'
)
output_attentions
=
False
output_attentions
=
False
if
output_hidden_states
:
if
output_hidden_states
:
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
output_hidden_states
=
False
output_hidden_states
=
False
if
return_dict
:
logger
.
warning_once
(
'return_dict is not supported for pipeline models at the moment'
)
return_dict
=
False
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
outputs
=
BertPipelineForwards
.
bert_model_forward
(
self
.
bert
,
outputs
=
BertPipelineForwards
.
bert_model_forward
(
self
.
bert
,
input_ids
,
input_ids
,
...
@@ -661,10 +652,6 @@ class BertPipelineForwards:
...
@@ -661,10 +652,6 @@ class BertPipelineForwards:
if
output_hidden_states
:
if
output_hidden_states
:
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
output_hidden_states
=
False
output_hidden_states
=
False
if
return_dict
:
logger
.
warning_once
(
'return_dict is not supported for pipeline models at the moment'
)
return_dict
=
False
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
outputs
=
BertPipelineForwards
.
bert_model_forward
(
self
.
bert
,
outputs
=
BertPipelineForwards
.
bert_model_forward
(
self
.
bert
,
input_ids
,
input_ids
,
...
@@ -753,10 +740,6 @@ class BertPipelineForwards:
...
@@ -753,10 +740,6 @@ class BertPipelineForwards:
if
output_hidden_states
:
if
output_hidden_states
:
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
output_hidden_states
=
False
output_hidden_states
=
False
if
return_dict
:
logger
.
warning_once
(
'return_dict is not supported for pipeline models at the moment'
)
return_dict
=
False
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
outputs
=
BertPipelineForwards
.
bert_model_forward
(
outputs
=
BertPipelineForwards
.
bert_model_forward
(
self
.
bert
,
self
.
bert
,
...
@@ -832,10 +815,6 @@ class BertPipelineForwards:
...
@@ -832,10 +815,6 @@ class BertPipelineForwards:
if
output_hidden_states
:
if
output_hidden_states
:
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
output_hidden_states
=
False
output_hidden_states
=
False
if
return_dict
:
logger
.
warning_once
(
'return_dict is not supported for pipeline models at the moment'
)
return_dict
=
False
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
# in our pipeline design,input ids are copied for every stage and shouldn't be none
# in our pipeline design,input ids are copied for every stage and shouldn't be none
# the input_ids for multiple choice model is [batch_size, num_choices, sequence_length]
# the input_ids for multiple choice model is [batch_size, num_choices, sequence_length]
...
@@ -928,10 +907,6 @@ class BertPipelineForwards:
...
@@ -928,10 +907,6 @@ class BertPipelineForwards:
if
output_hidden_states
:
if
output_hidden_states
:
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
output_hidden_states
=
False
output_hidden_states
=
False
if
return_dict
:
logger
.
warning_once
(
'return_dict is not supported for pipeline models at the moment'
)
return_dict
=
False
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
outputs
=
BertPipelineForwards
.
bert_model_forward
(
outputs
=
BertPipelineForwards
.
bert_model_forward
(
self
.
bert
,
self
.
bert
,
...
...
colossalai/shardformer/modeling/bloom.py
View file @
da3cef27
...
@@ -313,9 +313,6 @@ class BloomPipelineForwards:
...
@@ -313,9 +313,6 @@ class BloomPipelineForwards:
if
output_hidden_states
:
if
output_hidden_states
:
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
output_hidden_states
=
False
output_hidden_states
=
False
if
return_dict
:
logger
.
warning_once
(
'return_dict is not supported for pipeline models at the moment'
)
return_dict
=
False
transformer_outputs
=
BloomPipelineForwards
.
bloom_model_forward
(
self
.
transformer
,
transformer_outputs
=
BloomPipelineForwards
.
bloom_model_forward
(
self
.
transformer
,
input_ids
,
input_ids
,
...
@@ -411,9 +408,6 @@ class BloomPipelineForwards:
...
@@ -411,9 +408,6 @@ class BloomPipelineForwards:
if
output_hidden_states
:
if
output_hidden_states
:
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
output_hidden_states
=
False
output_hidden_states
=
False
if
return_dict
:
logger
.
warning_once
(
'return_dict is not supported for pipeline models at the moment'
)
return_dict
=
False
transformer_outputs
=
BloomPipelineForwards
.
bloom_model_forward
(
transformer_outputs
=
BloomPipelineForwards
.
bloom_model_forward
(
self
.
transformer
,
self
.
transformer
,
...
@@ -537,9 +531,6 @@ class BloomPipelineForwards:
...
@@ -537,9 +531,6 @@ class BloomPipelineForwards:
if
output_hidden_states
:
if
output_hidden_states
:
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
output_hidden_states
=
False
output_hidden_states
=
False
if
return_dict
:
logger
.
warning_once
(
'return_dict is not supported for pipeline models at the moment'
)
return_dict
=
False
transformer_outputs
=
BloomPipelineForwards
.
bloom_model_forward
(
transformer_outputs
=
BloomPipelineForwards
.
bloom_model_forward
(
self
.
transformer
,
self
.
transformer
,
...
@@ -626,9 +617,6 @@ class BloomPipelineForwards:
...
@@ -626,9 +617,6 @@ class BloomPipelineForwards:
if
output_hidden_states
:
if
output_hidden_states
:
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
logger
.
warning_once
(
'output_hidden_states=True is not supported for pipeline models at the moment.'
)
output_hidden_states
=
False
output_hidden_states
=
False
if
return_dict
:
logger
.
warning_once
(
'return_dict is not supported for pipeline models at the moment'
)
return_dict
=
False
outputs
=
BloomPipelineForwards
.
bloom_model_forward
(
outputs
=
BloomPipelineForwards
.
bloom_model_forward
(
self
.
transformer
,
self
.
transformer
,
...
...
colossalai/shardformer/modeling/gpt2.py
View file @
da3cef27
...
@@ -52,6 +52,8 @@ class GPT2PipelineForwards:
...
@@ -52,6 +52,8 @@ class GPT2PipelineForwards:
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
# Please refer to original code of transformers for more details.
# Please refer to original code of transformers for more details.
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
# Preprocess passed in arguments
# Preprocess passed in arguments
...
...
colossalai/shardformer/policies/opt.py
View file @
da3cef27
...
@@ -8,6 +8,18 @@ import torch
...
@@ -8,6 +8,18 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch
import
Tensor
,
nn
from
torch
import
Tensor
,
nn
from
torch.nn
import
BCEWithLogitsLoss
,
CrossEntropyLoss
,
MSELoss
from
torch.nn
import
BCEWithLogitsLoss
,
CrossEntropyLoss
,
MSELoss
from
transformers.modeling_outputs
import
(
BaseModelOutputWithPast
,
CausalLMOutputWithPast
,
QuestionAnsweringModelOutput
,
SequenceClassifierOutputWithPast
,
)
from
transformers.models.opt.modeling_opt
import
(
OPTForCausalLM
,
OPTForQuestionAnswering
,
OPTForSequenceClassification
,
OPTModel
,
)
from
colossalai.pipeline.stage_manager
import
PipelineStageManager
from
colossalai.pipeline.stage_manager
import
PipelineStageManager
from
colossalai.shardformer.layer
import
FusedLayerNorm
,
Linear1D_Col
,
Linear1D_Row
,
VocabParallelEmbedding1D
from
colossalai.shardformer.layer
import
FusedLayerNorm
,
Linear1D_Col
,
Linear1D_Row
,
VocabParallelEmbedding1D
...
@@ -317,7 +329,7 @@ class OPTPipelineForwards:
...
@@ -317,7 +329,7 @@ class OPTPipelineForwards:
@
staticmethod
@
staticmethod
def
opt_model_forward
(
def
opt_model_forward
(
self
:
'
OPTModel
'
,
self
:
OPTModel
,
input_ids
:
torch
.
LongTensor
=
None
,
input_ids
:
torch
.
LongTensor
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -330,7 +342,7 @@ class OPTPipelineForwards:
...
@@ -330,7 +342,7 @@ class OPTPipelineForwards:
stage_manager
:
Optional
[
PipelineStageManager
]
=
None
,
stage_manager
:
Optional
[
PipelineStageManager
]
=
None
,
hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
stage_index
:
Optional
[
List
[
int
]]
=
None
,
stage_index
:
Optional
[
List
[
int
]]
=
None
,
)
->
Union
[
Tuple
,
'
BaseModelOutputWithPast
'
]:
)
->
Union
[
Tuple
,
BaseModelOutputWithPast
]:
'''
'''
This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
'''
'''
...
@@ -506,7 +518,7 @@ class OPTPipelineForwards:
...
@@ -506,7 +518,7 @@ class OPTPipelineForwards:
@
staticmethod
@
staticmethod
def
opt_for_causal_lm_forward
(
def
opt_for_causal_lm_forward
(
self
:
'
OPTForCausalLM
'
,
self
:
OPTForCausalLM
,
input_ids
:
torch
.
LongTensor
=
None
,
input_ids
:
torch
.
LongTensor
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -520,7 +532,7 @@ class OPTPipelineForwards:
...
@@ -520,7 +532,7 @@ class OPTPipelineForwards:
stage_manager
:
Optional
[
PipelineStageManager
]
=
None
,
stage_manager
:
Optional
[
PipelineStageManager
]
=
None
,
hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
stage_index
:
Optional
[
List
[
int
]]
=
None
,
stage_index
:
Optional
[
List
[
int
]]
=
None
,
)
->
Union
[
Tuple
,
'
CausalLMOutputWithPast
'
]:
)
->
Union
[
Tuple
,
CausalLMOutputWithPast
]:
r
"""
r
"""
Args:
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
...
@@ -646,7 +658,7 @@ class OPTPipelineForwards:
...
@@ -646,7 +658,7 @@ class OPTPipelineForwards:
@
staticmethod
@
staticmethod
def
opt_for_sequence_classification_forward
(
def
opt_for_sequence_classification_forward
(
self
:
'
OPTForSequenceClassification
'
,
self
:
OPTForSequenceClassification
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
FloatTensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
FloatTensor
]
=
None
,
head_mask
:
Optional
[
torch
.
FloatTensor
]
=
None
,
head_mask
:
Optional
[
torch
.
FloatTensor
]
=
None
,
...
@@ -660,7 +672,7 @@ class OPTPipelineForwards:
...
@@ -660,7 +672,7 @@ class OPTPipelineForwards:
stage_manager
:
Optional
[
PipelineStageManager
]
=
None
,
stage_manager
:
Optional
[
PipelineStageManager
]
=
None
,
hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
stage_index
:
Optional
[
List
[
int
]]
=
None
,
stage_index
:
Optional
[
List
[
int
]]
=
None
,
)
->
Union
[
Tuple
,
'
SequenceClassifierOutputWithPast
'
]:
)
->
Union
[
Tuple
,
SequenceClassifierOutputWithPast
]:
r
"""
r
"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
...
@@ -746,7 +758,7 @@ class OPTPipelineForwards:
...
@@ -746,7 +758,7 @@ class OPTPipelineForwards:
@
staticmethod
@
staticmethod
def
opt_for_question_answering_forward
(
def
opt_for_question_answering_forward
(
self
:
'
OPTForQuestionAnswering
'
,
self
:
OPTForQuestionAnswering
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
FloatTensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
FloatTensor
]
=
None
,
head_mask
:
Optional
[
torch
.
FloatTensor
]
=
None
,
head_mask
:
Optional
[
torch
.
FloatTensor
]
=
None
,
...
@@ -761,7 +773,7 @@ class OPTPipelineForwards:
...
@@ -761,7 +773,7 @@ class OPTPipelineForwards:
stage_manager
:
Optional
[
PipelineStageManager
]
=
None
,
stage_manager
:
Optional
[
PipelineStageManager
]
=
None
,
hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
stage_index
:
Optional
[
List
[
int
]]
=
None
,
stage_index
:
Optional
[
List
[
int
]]
=
None
,
)
->
Union
[
Tuple
,
'
QuestionAnsweringModelOutput
'
]:
)
->
Union
[
Tuple
,
QuestionAnsweringModelOutput
]:
r
"""
r
"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Labels for position (index) of the start of the labelled span for computing the token classification loss.
...
...
tests/test_shardformer/test_model/test_pure_pipeline.py
View file @
da3cef27
import
copy
import
copy
import
random
import
random
from
contextlib
import
nullcontext
from
typing
import
Any
,
Callable
,
Iterator
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Iterator
,
List
,
Optional
,
Tuple
import
numpy
as
np
import
numpy
as
np
...
@@ -100,8 +99,8 @@ class data_loader():
...
@@ -100,8 +99,8 @@ class data_loader():
return
torch
.
ones
((
4
,
128
),
dtype
=
torch
.
int
).
cuda
()
*
10
return
torch
.
ones
((
4
,
128
),
dtype
=
torch
.
int
).
cuda
()
*
10
def
loss
(
x
,
y
):
def
loss
(
y
,
x
):
return
(
x
[
0
].
float
().
mean
()
-
y
[
0
].
float
().
mean
())
return
(
y
[
0
].
float
().
mean
()
-
x
[
0
].
float
().
mean
())
@
parameterize
(
'enable_fused_normalization'
,
[
False
])
@
parameterize
(
'enable_fused_normalization'
,
[
False
])
...
@@ -137,7 +136,7 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la
...
@@ -137,7 +136,7 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la
batch
=
next
(
data_iter
)
batch
=
next
(
data_iter
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
y
=
model_copy
(
batch
)
y
=
model_copy
(
batch
)
org_loss
=
loss
(
batch
,
y
)
org_loss
=
loss
(
y
,
batch
)
optimizer
=
torch
.
optim
.
AdamW
(
org_model
.
parameters
(),
lr
=
1e-3
)
optimizer
=
torch
.
optim
.
AdamW
(
org_model
.
parameters
(),
lr
=
1e-3
)
schedule
=
OneForwardOneBackwardSchedule
(
num_microbatches
,
stage_manager
)
schedule
=
OneForwardOneBackwardSchedule
(
num_microbatches
,
stage_manager
)
shard_config
=
ShardConfig
(
enable_fused_normalization
=
enable_fused_normalization
,
shard_config
=
ShardConfig
(
enable_fused_normalization
=
enable_fused_normalization
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment