Commit a14d3520 authored by Baizhou Zhang's avatar Baizhou Zhang Committed by Hongxin Liu
Browse files

[pipeline] add pipeline forward for variants of gpt2 (#4238)

* add forward for GPTLMHeadModel

* add test for gpt_lm

* arranging get_held_layers method

* arrange forward replacement

* add forward for GPT2ForTokenClassification

* add forward for GPT2ForSequenceClassification

* fix test_shard_gpt2.py

* add GPT2DoubleHeadsmodel & fix bugs

* add id checking in get_shared_params
parent 7e4de520
This diff is collapsed.
...@@ -5,15 +5,9 @@ import colossalai ...@@ -5,15 +5,9 @@ import colossalai
from colossalai.cluster import ProcessGroupMesh from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import ( from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward from tests.test_shardformer.test_model._utils import build_pipeline_model
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
...@@ -21,8 +15,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo ...@@ -21,8 +15,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
pass pass
@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False]) @parameterize('enable_tensor_parallelism', [False])
@parameterize('enable_fused_normalization', [False])
@parameterize('use_lazy_init', [False]) @parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_gpt2 #TODO: merge this into test_shard_gpt2
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
...@@ -32,30 +26,30 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz ...@@ -32,30 +26,30 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
stage_manager = PipelineStageManager(pg_mesh, PP_DIM) stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
if name != "transformers_gpt":
continue
inputs = data_gen_fn() inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()} inputs = {k: v.cuda() for k, v in inputs.items()}
input_ids, _ = inputs['input_ids'], inputs['attention_mask']
batch_size, seq_len = input_ids.shape
hidden_size = 768
hidden_state_shape = (batch_size, seq_len, hidden_size)
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, if not stage_manager.is_first_stage():
enable_tensor_parallelism, use_lazy_init) # change inputs if not the first stage
org_model.train()
org_output = org_model(**inputs)
hidden_state_shape = org_output['last_hidden_state'].shape
if stage_manager.is_first_stage():
output = sharded_model(**inputs)
assert output['hidden_states'].shape == hidden_state_shape
else:
attention_mask = inputs['attention_mask']
hidden_states = torch.zeros(*hidden_state_shape).cuda() hidden_states = torch.zeros(*hidden_state_shape).cuda()
output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask) inputs['input_ids'] = None
if stage_manager.is_last_stage(): inputs['hidden_states'] = hidden_states
assert output['last_hidden_state'].shape == hidden_state_shape
else: _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
assert output['hidden_states'].shape == hidden_state_shape enable_tensor_parallelism, use_lazy_init)
sharded_model.train()
output = sharded_model(**inputs)
if stage_manager.is_last_stage():
if name != 'transformers_gpt':
assert output.loss is not None
else:
assert output['hidden_states'].shape == hidden_state_shape
torch.cuda.empty_cache() torch.cuda.empty_cache()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment