Commit 3c9198bc authored by Sai Ganesh Bandiatmakuri's avatar Sai Ganesh Bandiatmakuri Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 307070264
parent dbf2f102
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils to set Owner annotations on benchmarks.
@owner_utils.Owner('owner_team/user') can be set either at the benchmark class
level / benchmark method level or both.
Runner frameworks can use owner_utils.GetOwner(benchmark_method) to get the
actual owner. Python inheritance for the owner attribute is respected. (E.g
method level owner takes precedence over class level).
See owner_utils_test for associated tests and more examples.
The decorator can be applied both at the method level and at the class level.
Simple example:
===============
class MLBenchmark:
@Owner('example_id')
def benchmark_method_1_gpu(self):
return True
"""
def Owner(owner_name):
"""Sets the owner attribute on a decorated method or class."""
def _Wrapper(func_or_class):
"""Sets the benchmark owner attribute."""
func_or_class.__benchmark__owner__ = owner_name
return func_or_class
return _Wrapper
def GetOwner(benchmark_method_or_class):
"""Gets the inherited owner attribute for this benchmark.
Checks for existence of __benchmark__owner__. If it's not present, looks for
it in the parent class's attribute list.
Args:
benchmark_method_or_class: A benchmark method or class.
Returns:
string - the associated owner if present / None.
"""
if hasattr(benchmark_method_or_class, '__benchmark__owner__'):
return benchmark_method_or_class.__benchmark__owner__
elif hasattr(benchmark_method_or_class, '__self__'):
if hasattr(benchmark_method_or_class.__self__, '__benchmark__owner__'):
return benchmark_method_or_class.__self__.__benchmark__owner__
return None
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.benchmark.owner_utils."""
from absl.testing import absltest
from official.benchmark import owner_utils
@owner_utils.Owner('static_owner')
def static_function(foo=5):
return foo
def static_function_without_owner(foo=5):
return foo
class BenchmarkClassWithoutOwner:
def method_without_owner(self):
return 100
@owner_utils.Owner('method_owner')
def method_with_owner(self):
return 200
@owner_utils.Owner('class_owner')
class SomeBenchmarkClass:
def method_inherited_owner(self):
return 123
@owner_utils.Owner('method_owner')
def method_override_owner(self):
return 345
@owner_utils.Owner('new_class_owner')
class InheritedClass(SomeBenchmarkClass):
def method_inherited_owner(self):
return 456
@owner_utils.Owner('new_method_owner')
def method_override_owner(self):
return 567
class OwnerUtilsTest(absltest.TestCase):
"""Tests to assert for owner decorator functionality."""
def test_owner_tag_missing(self):
self.assertEqual(None, owner_utils.GetOwner(static_function_without_owner))
benchmark_class = BenchmarkClassWithoutOwner()
self.assertEqual(None,
owner_utils.GetOwner(benchmark_class.method_without_owner))
self.assertEqual(100, benchmark_class.method_without_owner())
self.assertEqual('method_owner',
owner_utils.GetOwner(benchmark_class.method_with_owner))
self.assertEqual(200, benchmark_class.method_with_owner())
def test_owner_attributes_static(self):
self.assertEqual('static_owner', owner_utils.GetOwner(static_function))
self.assertEqual(5, static_function(5))
def test_owner_attributes_per_class(self):
level1 = SomeBenchmarkClass()
self.assertEqual('class_owner',
owner_utils.GetOwner(level1.method_inherited_owner))
self.assertEqual(123, level1.method_inherited_owner())
self.assertEqual('method_owner',
owner_utils.GetOwner(level1.method_override_owner))
self.assertEqual(345, level1.method_override_owner())
def test_owner_attributes_inherited_class(self):
level2 = InheritedClass()
self.assertEqual('new_class_owner',
owner_utils.GetOwner(level2.method_inherited_owner))
self.assertEqual(456, level2.method_inherited_owner())
self.assertEqual('new_method_owner',
owner_utils.GetOwner(level2.method_override_owner))
self.assertEqual(567, level2.method_override_owner())
if __name__ == '__main__':
absltest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment