Unverified Commit 6050f377 authored by wukong1992's avatar wukong1992 Committed by GitHub
Browse files

[booster] removed models that don't support fsdp (#3744)


Co-authored-by: default avatar纪少敏 <jishaomin@jishaomindeMBP.lan>
parent afb239bb
......@@ -46,7 +46,10 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
def check_torch_fsdp_plugin():
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
if 'diffusers' in name:
if any(element in name for element in [
'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet',
'torchvision_inception_v3'
]):
continue
run_fn(model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache()
......@@ -58,12 +61,6 @@ def run_dist(rank, world_size, port):
check_torch_fsdp_plugin()
# FIXME: this test is not working
@pytest.mark.skip(
"ValueError: expected to be in states [<TrainingState_.BACKWARD_PRE: 3>, <TrainingState_.BACKWARD_POST: 4>] but current state is TrainingState_.IDLE"
)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher")
@rerun_if_address_is_in_use()
def test_torch_fsdp_plugin():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment