Unverified Commit 48795317 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

[test fetcher] Always include the directly related test files (#30050)



* fix

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent de11d0bd
...@@ -958,10 +958,25 @@ def create_module_to_test_map( ...@@ -958,10 +958,25 @@ def create_module_to_test_map(
model_tests = {Path(t).parts[2] for t in tests if t.startswith("tests/models/")} model_tests = {Path(t).parts[2] for t in tests if t.startswith("tests/models/")}
return len(model_tests) > num_model_tests // 2 return len(model_tests) > num_model_tests // 2
def filter_tests(tests): # for each module (if specified in the argument `module`) of the form `models/my_model` (i.e. starting with it),
return [t for t in tests if not t.startswith("tests/models/") or Path(t).parts[2] in IMPORTANT_MODELS] # we always keep the tests (those are already in the argument `tests`) which are in `tests/models/my_model`.
# This is to avoid them being excluded when a module has many impacted tests: the directly related test files should
# always be included!
def filter_tests(tests, module=""):
return [
t
for t in tests
if not t.startswith("tests/models/")
or Path(t).parts[2] in IMPORTANT_MODELS
# at this point, `t` is of the form `tests/models/my_model`, and we check if `models/my_model`
# (i.e. `parts[1:3]`) is in `module`.
or "/".join(Path(t).parts[1:3]) in module
]
return {module: (filter_tests(tests) if has_many_models(tests) else tests) for module, tests in test_map.items()} return {
module: (filter_tests(tests, module=module) if has_many_models(tests) else tests)
for module, tests in test_map.items()
}
def check_imports_all_exist(): def check_imports_all_exist():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment