Commit 34f0e34a authored by Jianghai's avatar Jianghai Committed by Hongxin Liu
Browse files

[pipeline] finish bloom models pipeline and tests (#4223)

* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* finish bloom model

* test shard gpt2

* clear cache

* support all bloom models

* add bloom models policies

* finish bloom pipeline and tests

* add set pipeline

* finish bloom
parent e7cc62d7
This diff is collapsed.
```diff
@@ -46,23 +46,22 @@ def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
-    x = torch.randint(0, 1000, (2, 3)).cuda()
-    hidden_states = torch.randint(0, 1000, (2, 3, 64)).to(torch.float32).cuda()
+    x = torch.randint(0, 1000, (1, 3)).cuda()
+    hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda()
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        if name == 'transformers_bloom':
-            org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
-                                                            enable_tensor_parallelism, use_lazy_init)
-            if stage_manager.stage == 0:
-                attention_mask = torch.ones_like(x).cuda()
-                output = sharded_model(input_ids=x, attention_mask=attention_mask)
-                assert output['hidden_states'].shape == (2, 3, 64)
-            else:
-                attention_mask = torch.ones((2, 3)).cuda()
-                output = sharded_model(
-                    hidden_states=hidden_states,
-                    attention_mask=attention_mask,
-                )
-                assert output[0].shape == (2, 3, 64)
+        org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
+                                                        enable_tensor_parallelism, use_lazy_init)
+        if stage_manager.stage == 0:
+            attention_mask = torch.ones_like(x).cuda()
+            output = sharded_model(input_ids=x, attention_mask=attention_mask)
+            assert output['hidden_states'].shape == (1, 3, 64)
+        else:
+            attention_mask = torch.ones((1, 3)).cuda()
+            output = sharded_model(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+            )
+            assert output[0].shape[0] == 1
     torch.cuda.empty_cache()
...
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment