Commit da3cef27, authored by Baizhou Zhang, committed by Hongxin Liu
Browse files

[pipeline] fix return_dict/fix pure_pipeline_test (#4331)

parent 411cf1d2
import warnings
from typing import Any, Dict, List, Optional, Tuple
import torch
......@@ -277,9 +278,6 @@ class BertPipelineForwards:
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
outputs = BertPipelineForwards.bert_model_forward(
self.bert,
......@@ -387,9 +385,6 @@ class BertPipelineForwards:
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
outputs = BertPipelineForwards.bert_model_forward(
self.bert,
......@@ -478,9 +473,6 @@ class BertPipelineForwards:
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
outputs = BertPipelineForwards.bert_model_forward(
self.bert,
......@@ -579,16 +571,15 @@ class BertPipelineForwards:
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = BertPipelineForwards.bert_model_forward(self.bert,
input_ids,
......@@ -661,10 +652,6 @@ class BertPipelineForwards:
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = BertPipelineForwards.bert_model_forward(self.bert,
input_ids,
......@@ -753,10 +740,6 @@ class BertPipelineForwards:
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = BertPipelineForwards.bert_model_forward(
self.bert,
......@@ -832,10 +815,6 @@ class BertPipelineForwards:
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# in our pipeline design,input ids are copied for every stage and shouldn't be none
# the input_ids for multiple choice model is [batch_size, num_choices, sequence_length]
......@@ -928,10 +907,6 @@ class BertPipelineForwards:
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = BertPipelineForwards.bert_model_forward(
self.bert,
......
......@@ -313,9 +313,6 @@ class BloomPipelineForwards:
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
transformer_outputs = BloomPipelineForwards.bloom_model_forward(self.transformer,
input_ids,
......@@ -411,9 +408,6 @@ class BloomPipelineForwards:
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
transformer_outputs = BloomPipelineForwards.bloom_model_forward(
self.transformer,
......@@ -537,9 +531,6 @@ class BloomPipelineForwards:
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
transformer_outputs = BloomPipelineForwards.bloom_model_forward(
self.transformer,
......@@ -626,9 +617,6 @@ class BloomPipelineForwards:
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
outputs = BloomPipelineForwards.bloom_model_forward(
self.transformer,
......
......@@ -52,6 +52,8 @@ class GPT2PipelineForwards:
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
# Please refer to original code of transformers for more details.
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
logger = logging.get_logger(__name__)
# Preprocess passed in arguments
......
......@@ -8,6 +8,18 @@ import torch
import torch.nn as nn
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
)
from transformers.models.opt.modeling_opt import (
OPTForCausalLM,
OPTForQuestionAnswering,
OPTForSequenceClassification,
OPTModel,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
......@@ -317,7 +329,7 @@ class OPTPipelineForwards:
@staticmethod
def opt_model_forward(
self: 'OPTModel',
self: OPTModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
......@@ -330,7 +342,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, 'BaseModelOutputWithPast']:
) -> Union[Tuple, BaseModelOutputWithPast]:
'''
This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
'''
......@@ -506,7 +518,7 @@ class OPTPipelineForwards:
@staticmethod
def opt_for_causal_lm_forward(
self: 'OPTForCausalLM',
self: OPTForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
......@@ -520,7 +532,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, 'CausalLMOutputWithPast']:
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
......@@ -646,7 +658,7 @@ class OPTPipelineForwards:
@staticmethod
def opt_for_sequence_classification_forward(
self: 'OPTForSequenceClassification',
self: OPTForSequenceClassification,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
......@@ -660,7 +672,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, 'SequenceClassifierOutputWithPast']:
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
......@@ -746,7 +758,7 @@ class OPTPipelineForwards:
@staticmethod
def opt_for_question_answering_forward(
self: 'OPTForQuestionAnswering',
self: OPTForQuestionAnswering,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
......@@ -761,7 +773,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, 'QuestionAnsweringModelOutput']:
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
......
import copy
import random
from contextlib import nullcontext
from typing import Any, Callable, Iterator, List, Optional, Tuple
import numpy as np
......@@ -100,8 +99,8 @@ class data_loader():
return torch.ones((4, 128), dtype=torch.int).cuda() * 10
def loss(x, y):
return (x[0].float().mean() - y[0].float().mean())
def loss(y, x):
return (y[0].float().mean() - x[0].float().mean())
@parameterize('enable_fused_normalization', [False])
......@@ -137,7 +136,7 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la
batch = next(data_iter)
with torch.no_grad():
y = model_copy(batch)
org_loss = loss(batch, y)
org_loss = loss(y, batch)
optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3)
schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager)
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment