Commit c4928698 authored by klhhhhh's avatar klhhhhh Committed by Hongxin Liu
Browse files

[shardformer] vit test finish and support

parent f60162b2
......@@ -56,7 +56,6 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('enable_tensor_parallelism', [True, False])
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
print(sub_model_zoo)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment