Commit da3cef27 authored by Baizhou Zhang's avatar Baizhou Zhang Committed by Hongxin Liu
Browse files

[pipeline] fix return_dict/fix pure_pipeline_test (#4331)

parent 411cf1d2
import warnings
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch import torch
...@@ -277,9 +278,6 @@ class BertPipelineForwards: ...@@ -277,9 +278,6 @@ class BertPipelineForwards:
if output_hidden_states: if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
outputs = BertPipelineForwards.bert_model_forward( outputs = BertPipelineForwards.bert_model_forward(
self.bert, self.bert,
...@@ -387,9 +385,6 @@ class BertPipelineForwards: ...@@ -387,9 +385,6 @@ class BertPipelineForwards:
if output_hidden_states: if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
outputs = BertPipelineForwards.bert_model_forward( outputs = BertPipelineForwards.bert_model_forward(
self.bert, self.bert,
...@@ -478,9 +473,6 @@ class BertPipelineForwards: ...@@ -478,9 +473,6 @@ class BertPipelineForwards:
if output_hidden_states: if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
outputs = BertPipelineForwards.bert_model_forward( outputs = BertPipelineForwards.bert_model_forward(
self.bert, self.bert,
...@@ -579,16 +571,15 @@ class BertPipelineForwards: ...@@ -579,16 +571,15 @@ class BertPipelineForwards:
FutureWarning, FutureWarning,
) )
labels = kwargs.pop("next_sentence_label") labels = kwargs.pop("next_sentence_label")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions: if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False output_attentions = False
if output_hidden_states: if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = BertPipelineForwards.bert_model_forward(self.bert, outputs = BertPipelineForwards.bert_model_forward(self.bert,
input_ids, input_ids,
...@@ -661,10 +652,6 @@ class BertPipelineForwards: ...@@ -661,10 +652,6 @@ class BertPipelineForwards:
if output_hidden_states: if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = BertPipelineForwards.bert_model_forward(self.bert, outputs = BertPipelineForwards.bert_model_forward(self.bert,
input_ids, input_ids,
...@@ -753,10 +740,6 @@ class BertPipelineForwards: ...@@ -753,10 +740,6 @@ class BertPipelineForwards:
if output_hidden_states: if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = BertPipelineForwards.bert_model_forward( outputs = BertPipelineForwards.bert_model_forward(
self.bert, self.bert,
...@@ -832,10 +815,6 @@ class BertPipelineForwards: ...@@ -832,10 +815,6 @@ class BertPipelineForwards:
if output_hidden_states: if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# in our pipeline design,input ids are copied for every stage and shouldn't be none # in our pipeline design,input ids are copied for every stage and shouldn't be none
# the input_ids for multiple choice model is [batch_size, num_choices, sequence_length] # the input_ids for multiple choice model is [batch_size, num_choices, sequence_length]
...@@ -928,10 +907,6 @@ class BertPipelineForwards: ...@@ -928,10 +907,6 @@ class BertPipelineForwards:
if output_hidden_states: if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = BertPipelineForwards.bert_model_forward( outputs = BertPipelineForwards.bert_model_forward(
self.bert, self.bert,
......
...@@ -313,9 +313,6 @@ class BloomPipelineForwards: ...@@ -313,9 +313,6 @@ class BloomPipelineForwards:
if output_hidden_states: if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
transformer_outputs = BloomPipelineForwards.bloom_model_forward(self.transformer, transformer_outputs = BloomPipelineForwards.bloom_model_forward(self.transformer,
input_ids, input_ids,
...@@ -411,9 +408,6 @@ class BloomPipelineForwards: ...@@ -411,9 +408,6 @@ class BloomPipelineForwards:
if output_hidden_states: if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
transformer_outputs = BloomPipelineForwards.bloom_model_forward( transformer_outputs = BloomPipelineForwards.bloom_model_forward(
self.transformer, self.transformer,
...@@ -537,9 +531,6 @@ class BloomPipelineForwards: ...@@ -537,9 +531,6 @@ class BloomPipelineForwards:
if output_hidden_states: if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
transformer_outputs = BloomPipelineForwards.bloom_model_forward( transformer_outputs = BloomPipelineForwards.bloom_model_forward(
self.transformer, self.transformer,
...@@ -626,9 +617,6 @@ class BloomPipelineForwards: ...@@ -626,9 +617,6 @@ class BloomPipelineForwards:
if output_hidden_states: if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
outputs = BloomPipelineForwards.bloom_model_forward( outputs = BloomPipelineForwards.bloom_model_forward(
self.transformer, self.transformer,
......
...@@ -52,6 +52,8 @@ class GPT2PipelineForwards: ...@@ -52,6 +52,8 @@ class GPT2PipelineForwards:
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
# Please refer to original code of transformers for more details. # Please refer to original code of transformers for more details.
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
# Preprocess passed in arguments # Preprocess passed in arguments
......
...@@ -8,6 +8,18 @@ import torch ...@@ -8,6 +8,18 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor, nn from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
)
from transformers.models.opt.modeling_opt import (
OPTForCausalLM,
OPTForQuestionAnswering,
OPTForSequenceClassification,
OPTModel,
)
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
...@@ -317,7 +329,7 @@ class OPTPipelineForwards: ...@@ -317,7 +329,7 @@ class OPTPipelineForwards:
@staticmethod @staticmethod
def opt_model_forward( def opt_model_forward(
self: 'OPTModel', self: OPTModel,
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,
...@@ -330,7 +342,7 @@ class OPTPipelineForwards: ...@@ -330,7 +342,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None, stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
) -> Union[Tuple, 'BaseModelOutputWithPast']: ) -> Union[Tuple, BaseModelOutputWithPast]:
''' '''
This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
''' '''
...@@ -506,7 +518,7 @@ class OPTPipelineForwards: ...@@ -506,7 +518,7 @@ class OPTPipelineForwards:
@staticmethod @staticmethod
def opt_for_causal_lm_forward( def opt_for_causal_lm_forward(
self: 'OPTForCausalLM', self: OPTForCausalLM,
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,
...@@ -520,7 +532,7 @@ class OPTPipelineForwards: ...@@ -520,7 +532,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None, stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
) -> Union[Tuple, 'CausalLMOutputWithPast']: ) -> Union[Tuple, CausalLMOutputWithPast]:
r""" r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
...@@ -646,7 +658,7 @@ class OPTPipelineForwards: ...@@ -646,7 +658,7 @@ class OPTPipelineForwards:
@staticmethod @staticmethod
def opt_for_sequence_classification_forward( def opt_for_sequence_classification_forward(
self: 'OPTForSequenceClassification', self: OPTForSequenceClassification,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
...@@ -660,7 +672,7 @@ class OPTPipelineForwards: ...@@ -660,7 +672,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None, stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
) -> Union[Tuple, 'SequenceClassifierOutputWithPast']: ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
...@@ -746,7 +758,7 @@ class OPTPipelineForwards: ...@@ -746,7 +758,7 @@ class OPTPipelineForwards:
@staticmethod @staticmethod
def opt_for_question_answering_forward( def opt_for_question_answering_forward(
self: 'OPTForQuestionAnswering', self: OPTForQuestionAnswering,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
...@@ -761,7 +773,7 @@ class OPTPipelineForwards: ...@@ -761,7 +773,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None, stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
) -> Union[Tuple, 'QuestionAnsweringModelOutput']: ) -> Union[Tuple, QuestionAnsweringModelOutput]:
r""" r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss. Labels for position (index) of the start of the labelled span for computing the token classification loss.
......
import copy import copy
import random import random
from contextlib import nullcontext
from typing import Any, Callable, Iterator, List, Optional, Tuple from typing import Any, Callable, Iterator, List, Optional, Tuple
import numpy as np import numpy as np
...@@ -100,8 +99,8 @@ class data_loader(): ...@@ -100,8 +99,8 @@ class data_loader():
return torch.ones((4, 128), dtype=torch.int).cuda() * 10 return torch.ones((4, 128), dtype=torch.int).cuda() * 10
def loss(x, y): def loss(y, x):
return (x[0].float().mean() - y[0].float().mean()) return (y[0].float().mean() - x[0].float().mean())
@parameterize('enable_fused_normalization', [False]) @parameterize('enable_fused_normalization', [False])
...@@ -137,7 +136,7 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la ...@@ -137,7 +136,7 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la
batch = next(data_iter) batch = next(data_iter)
with torch.no_grad(): with torch.no_grad():
y = model_copy(batch) y = model_copy(batch)
org_loss = loss(batch, y) org_loss = loss(y, batch)
optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3) optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3)
schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager) schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager)
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment