OpenDAS / ColossalAI · Commit da3cef27

[pipeline] fix return_dict/fix pure_pipeline_test (#4331)

Authored Jul 27, 2023 by Baizhou Zhang; committed by Hongxin Liu, Aug 15, 2023
Parent: 411cf1d2
Showing 5 changed files with 29 additions and 53 deletions (+29 −53)
colossalai/shardformer/modeling/bert.py (+4 −29)
colossalai/shardformer/modeling/bloom.py (+0 −12)
colossalai/shardformer/modeling/gpt2.py (+2 −0)
colossalai/shardformer/policies/opt.py (+20 −8)
tests/test_shardformer/test_model/test_pure_pipeline.py (+3 −4)
colossalai/shardformer/modeling/bert.py
```diff
+import warnings
 from typing import Any, Dict, List, Optional, Tuple

 import torch
...
```
```diff
@@ -277,9 +278,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
...
```
```diff
@@ -387,9 +385,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
...
```
```diff
@@ -478,9 +473,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
...
```
```diff
@@ -579,16 +571,15 @@ class BertPipelineForwards:
                 FutureWarning,
             )
             labels = kwargs.pop("next_sentence_label")
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if output_attentions:
             logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
             output_attentions = False
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
             input_ids,
...
```
```diff
@@ -661,10 +652,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
             input_ids,
...
```
```diff
@@ -753,10 +740,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
...
```
```diff
@@ -832,10 +815,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # in our pipeline design, input ids are copied for every stage and shouldn't be none
         # the input_ids for multiple choice model is [batch_size, num_choices, sequence_length]
...
```
```diff
@@ -928,10 +907,6 @@ class BertPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = BertPipelineForwards.bert_model_forward(
             self.bert,
...
```
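The pattern removed in every hunk above is the forced `return_dict = False` override, which broke callers that expected a structured `ModelOutput`. For context, here is a minimal sketch (illustrative code, not from ColossalAI or transformers) of the standard Hugging Face contract these forwards now follow: a `return_dict` of `None` defers to `config.use_return_dict`, and the flag only selects between the legacy tuple output and the structured output.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch


@dataclass
class ToyOutput:
    logits: torch.Tensor


class ToyConfig:
    use_return_dict = True  # transformers derives this from the config's flags


class ToyModel:
    def __init__(self) -> None:
        self.config = ToyConfig()

    def forward(self, x: torch.Tensor, return_dict: Optional[bool] = None) -> Union[Tuple, "ToyOutput"]:
        # The idiom the commit keeps: None means "defer to the config",
        # rather than being force-overridden to False.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        logits = x * 2
        if not return_dict:
            return (logits,)  # legacy tuple output
        return ToyOutput(logits=logits)  # structured output, e.g. outputs.logits


model = ToyModel()
assert isinstance(model.forward(torch.ones(2)), ToyOutput)
assert isinstance(model.forward(torch.ones(2), return_dict=False), tuple)
```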
colossalai/shardformer/modeling/bloom.py
```diff
@@ -313,9 +313,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         transformer_outputs = BloomPipelineForwards.bloom_model_forward(
             self.transformer,
             input_ids,
...
```
```diff
@@ -411,9 +408,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         transformer_outputs = BloomPipelineForwards.bloom_model_forward(
             self.transformer,
...
```
```diff
@@ -537,9 +531,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         transformer_outputs = BloomPipelineForwards.bloom_model_forward(
             self.transformer,
...
```
```diff
@@ -626,9 +617,6 @@ class BloomPipelineForwards:
         if output_hidden_states:
             logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
             output_hidden_states = False
-        if return_dict:
-            logger.warning_once('return_dict is not supported for pipeline models at the moment')
-            return_dict = False

         outputs = BloomPipelineForwards.bloom_model_forward(
             self.transformer,
...
```
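An aside on the `logger.warning_once` calls kept as context throughout bert.py and bloom.py: assuming the logger comes from transformers' logging utility (as the `logging.get_logger(__name__)` line in gpt2.py suggests), the message is deduplicated, which matters because these pipeline forwards run once per microbatch. A small sketch of that behaviour:

```python
# Sketch of warning_once deduplication, assuming the transformers logging
# utility is the logger used by these forwards (an assumption on our part).
from transformers.utils import logging

logger = logging.get_logger(__name__)

for _ in range(3):  # e.g. three microbatches hitting the same forward
    logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
# the warning is emitted once, not three times
```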
colossalai/shardformer/modeling/gpt2.py
```diff
@@ -52,6 +52,8 @@ class GPT2PipelineForwards:
         # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
         # Please refer to original code of transformers for more details.
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
         logger = logging.get_logger(__name__)

         # Preprocess passed in arguments
...
```
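The single functional addition here is the `use_return_dict` resolution line. In transformers, `config.use_return_dict` is a derived property (roughly `return_dict and not torchscript`), so the added line picks up whatever the user configured. A quick check of that standard behaviour, shown for context only (not part of this diff):

```python
from transformers import GPT2Config

print(GPT2Config().use_return_dict)                   # True by default
print(GPT2Config(return_dict=False).use_return_dict)  # False when disabled
```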
colossalai/shardformer/policies/opt.py
```diff
@@ -8,6 +8,18 @@ import torch
 import torch.nn as nn
 from torch import Tensor, nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutputWithPast,
+)
+from transformers.models.opt.modeling_opt import (
+    OPTForCausalLM,
+    OPTForQuestionAnswering,
+    OPTForSequenceClassification,
+    OPTModel,
+)

 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
...
```
```diff
@@ -317,7 +329,7 @@ class OPTPipelineForwards:

     @staticmethod
-    def opt_model_forward(self: 'OPTModel',
+    def opt_model_forward(self: OPTModel,
                           input_ids: torch.LongTensor = None,
                           attention_mask: Optional[torch.Tensor] = None,
                           head_mask: Optional[torch.Tensor] = None,
...
```
```diff
@@ -330,7 +342,7 @@ class OPTPipelineForwards:
                           stage_manager: Optional[PipelineStageManager] = None,
                           hidden_states: Optional[torch.FloatTensor] = None,
                           stage_index: Optional[List[int]] = None,
-                          ) -> Union[Tuple, 'BaseModelOutputWithPast']:
+                          ) -> Union[Tuple, BaseModelOutputWithPast]:
         '''
         This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
         '''
...
```
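The opt.py hunks swap quoted annotations such as `'OPTModel'` for the classes themselves, which are now imported at the top of the file. A quoted name is only a forward reference: it is never evaluated unless something calls `typing.get_type_hints()`, so a stale or misspelled name goes unnoticed and type checkers get no real class to work with, whereas an imported class fails fast at import time. A minimal illustration (hypothetical names, not from this repo):

```python
from typing import get_type_hints


class Widget:
    pass


def broken(w: 'Wdget') -> None:  # typo in the forward reference goes unnoticed at def time
    pass


def fixed(w: Widget) -> None:  # a misspelling here would raise NameError at import time
    pass


try:
    get_type_hints(broken)  # string annotations are only resolved on demand
except NameError as err:
    print(f"resolved only on demand: {err}")
```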
```diff
@@ -506,7 +518,7 @@ class OPTPipelineForwards:

     @staticmethod
-    def opt_for_causal_lm_forward(self: 'OPTForCausalLM',
+    def opt_for_causal_lm_forward(self: OPTForCausalLM,
                                   input_ids: torch.LongTensor = None,
                                   attention_mask: Optional[torch.Tensor] = None,
                                   head_mask: Optional[torch.Tensor] = None,
...
```
```diff
@@ -520,7 +532,7 @@ class OPTPipelineForwards:
                                   stage_manager: Optional[PipelineStageManager] = None,
                                   hidden_states: Optional[torch.FloatTensor] = None,
                                   stage_index: Optional[List[int]] = None,
-                                  ) -> Union[Tuple, 'CausalLMOutputWithPast']:
+                                  ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
...
```
```diff
@@ -646,7 +658,7 @@ class OPTPipelineForwards:

     @staticmethod
-    def opt_for_sequence_classification_forward(self: 'OPTForSequenceClassification',
+    def opt_for_sequence_classification_forward(self: OPTForSequenceClassification,
                                                 input_ids: Optional[torch.LongTensor] = None,
                                                 attention_mask: Optional[torch.FloatTensor] = None,
                                                 head_mask: Optional[torch.FloatTensor] = None,
...
```
```diff
@@ -660,7 +672,7 @@ class OPTPipelineForwards:
                                                 stage_manager: Optional[PipelineStageManager] = None,
                                                 hidden_states: Optional[torch.FloatTensor] = None,
                                                 stage_index: Optional[List[int]] = None,
-                                                ) -> Union[Tuple, 'SequenceClassifierOutputWithPast']:
+                                                ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
...
```
```diff
@@ -746,7 +758,7 @@ class OPTPipelineForwards:

     @staticmethod
-    def opt_for_question_answering_forward(self: 'OPTForQuestionAnswering',
+    def opt_for_question_answering_forward(self: OPTForQuestionAnswering,
                                            input_ids: Optional[torch.LongTensor] = None,
                                            attention_mask: Optional[torch.FloatTensor] = None,
                                            head_mask: Optional[torch.FloatTensor] = None,
...
```
```diff
@@ -761,7 +773,7 @@ class OPTPipelineForwards:
                                            stage_manager: Optional[PipelineStageManager] = None,
                                            hidden_states: Optional[torch.FloatTensor] = None,
                                            stage_index: Optional[List[int]] = None,
-                                           ) -> Union[Tuple, 'QuestionAnsweringModelOutput']:
+                                           ) -> Union[Tuple, QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
...
```
tests/test_shardformer/test_model/test_pure_pipeline.py
```diff
 import copy
 import random
 from contextlib import nullcontext
 from typing import Any, Callable, Iterator, List, Optional, Tuple

 import numpy as np
...
```
```diff
@@ -100,8 +99,8 @@ class data_loader():
         return torch.ones((4, 128), dtype=torch.int).cuda() * 10


-def loss(x, y):
-    return (x[0].float().mean() - y[0].float().mean())
+def loss(y, x):
+    return (y[0].float().mean() - x[0].float().mean())


 @parameterize('enable_fused_normalization', [False])
...
```
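The test's `loss` helper now takes the model output first and the reference batch second, and the call site below is updated to match. This mirrors PyTorch's usual `criterion(input, target)` ordering, shown here only to illustrate the convention (standard PyTorch, not part of this diff):

```python
import torch
from torch import nn

criterion = nn.MSELoss()
output = torch.randn(4, 8)  # the model's prediction comes first
target = torch.randn(4, 8)  # the reference/target comes second
print(criterion(output, target))
```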
```diff
@@ -137,7 +136,7 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la
     batch = next(data_iter)
     with torch.no_grad():
         y = model_copy(batch)
-    org_loss = loss(batch, y)
+    org_loss = loss(y, batch)
     optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3)
     schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager)
     shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
...
```