Commit 34f0e34a authored by Jianghai's avatar Jianghai Committed by Hongxin Liu
Browse files

[pipeline] finish bloom models pipeline and tests (#4223)

* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* finish bloom model

* test shard gpt2

* clear cache

* support all bloom models

* add bloom models policies

* finish bloom pipeline and tests

* add set pipeline

* finish bloom
parent e7cc62d7
This diff is collapsed.
......@@ -46,23 +46,22 @@ def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_la
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
x = torch.randint(0, 1000, (2, 3)).cuda()
hidden_states = torch.randint(0, 1000, (2, 3, 64)).to(torch.float32).cuda()
x = torch.randint(0, 1000, (1, 3)).cuda()
hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda()
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name == 'transformers_bloom':
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask)
assert output['hidden_states'].shape == (2, 3, 64)
else:
attention_mask = torch.ones((2, 3)).cuda()
output = sharded_model(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
assert output[0].shape == (2, 3, 64)
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask)
assert output['hidden_states'].shape == (1, 3, 64)
else:
attention_mask = torch.ones((1, 3)).cuda()
output = sharded_model(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
assert output[0].shape[0] == 1
torch.cuda.empty_cache()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment