Unverified Commit de13a951 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[Flax] Update no init test for Flax v0.7.1 (#28735)

parent abe0289e
...@@ -984,7 +984,7 @@ class FlaxModelTesterMixin: ...@@ -984,7 +984,7 @@ class FlaxModelTesterMixin:
# Check if we params can be properly initialized when calling init_weights # Check if we params can be properly initialized when calling init_weights
params = model.init_weights(model.key, model.input_shape) params = model.init_weights(model.key, model.input_shape)
self.assertIsInstance(params, FrozenDict) assert isinstance(params, (dict, FrozenDict)), f"params are not an instance of {FrozenDict}"
# Check if all required parmas are initialized # Check if all required parmas are initialized
keys = set(flatten_dict(unfreeze(params)).keys()) keys = set(flatten_dict(unfreeze(params)).keys())
self.assertTrue(all(k in keys for k in model.required_params)) self.assertTrue(all(k in keys for k in model.required_params))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment