stat_utils.py 2.89 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Statistics utility functions of NCF."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

21
22
23
24
25
26
27
28
import atexit
from collections import deque
import multiprocessing
import os
import struct
import sys
import threading
import time
29

30
import numpy as np
31

32
from official.recommendation import popen_helper
33
34


35
36
def random_int32():
  """Draw a single random non-negative int32 from NumPy's global RNG.

  Returns:
    An int32 value in [0, np.iinfo(np.int32).max). The upper bound is
    exclusive, so the int32 maximum itself is never produced.
  """
  upper_bound = np.iinfo(np.int32).max
  return np.random.randint(0, upper_bound, dtype=np.int32)
37
38


39
40
41
42
43
44
45
def permutation(args):
  """Return a shuffled ``np.arange(x)`` as an int32 array.

  Args:
    args: A ``(x, seed)`` tuple. ``x`` is the number of elements to permute.
      ``seed`` is an integer seed for ``np.random.RandomState``. When it is
      None, a fresh 32-bit seed is drawn from ``os.urandom``. Both values are
      packed into one argument, presumably so this can be used with
      single-argument map APIs such as multiprocessing pools (confirm
      against callers).

  Returns:
    A 1-D int32 array containing each of 0..x-1 exactly once.
  """
  x, seed = args
  # Only None means "unseeded". 0 is a valid seed (random_int32() can return
  # it) and must still yield a reproducible permutation, so `seed or ...`
  # would be wrong here.
  if seed is None:
    seed = struct.unpack("<L", os.urandom(4))[0]
  state = np.random.RandomState(seed=seed)  # pylint: disable=no-member
  output = np.arange(x, dtype=np.int32)
  state.shuffle(output)
  return output
46
47


48
49
50
51
52
53
def very_slightly_biased_randint(max_val_vector):
  """Draw one random integer in [0, bound) for every bound in max_val_vector.

  Each entry starts as a uniform uint64 draw from [0, 2**64 - 1). It is then
  reduced modulo its bound. Modulo reduction favors small values slightly
  unless the bound divides the sample range, but the effect is negligible
  when the bounds are far below 2**64. That is where the function's name
  comes from.

  Args:
    max_val_vector: NumPy integer array of exclusive upper bounds. The
      entries are presumably positive; zero or negative bounds are not
      rejected by the uint64 cast (TODO confirm callers never pass them).

  Returns:
    An array with the same shape and dtype as max_val_vector.
  """
  wide = np.uint64
  raw_draws = np.random.randint(0, np.iinfo(wide).max,
                                size=max_val_vector.shape, dtype=wide)
  bounded = np.mod(raw_draws, max_val_vector.astype(wide))
  return bounded.astype(max_val_vector.dtype)
54

55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87

def mask_duplicates(x, axis=1):  # type: (np.ndarray, int) -> np.ndarray
  """Identify duplicates from sampling with replacement.

  Within each row, the first occurrence of a value (in original column order)
  is marked 0, and every later repeat of that value is marked 1.

  Args:
    x: A 2D NumPy array of samples. Any dtype that supports sorting and `==`
      works, including bool and string arrays. NaN never equals itself, so
      NaNs are never flagged as duplicates.
    axis: The axis along which to de-dupe. Only 1 is supported.

  Returns:
    An integer NumPy array with the same shape as x. An entry is 1 if that
    element appeared previously along axis 1, else 0.

  Raises:
    NotImplementedError: If axis is not 1.
  """
  if axis != 1:
    raise NotImplementedError("mask_duplicates only supports axis=1.")

  rows = np.arange(x.shape[0])[:, np.newaxis]

  # A stable sort keeps equal values in their original column order, so the
  # first occurrence of each value lands first within its run.
  x_sort_ind = np.argsort(x, axis=1, kind="mergesort")
  sorted_x = x[rows, x_sort_ind]

  # In sorted order, an element is a repeat iff it equals its left neighbour.
  # Comparing with `==` rather than subtracting keeps this correct for
  # repeated +/-inf (inf - inf is NaN) and for bool/string dtypes, which do
  # not support `-`. The first column can never be a repeat.
  is_dup_sorted = np.zeros(sorted_x.shape, dtype=bool)
  is_dup_sorted[:, 1:] = sorted_x[:, 1:] == sorted_x[:, :-1]

  # Scatter the flags back to their original positions. This is O(n), unlike
  # a second argsort to build the inverse permutation.
  is_dup = np.empty_like(is_dup_sorted)
  is_dup[rows, x_sort_ind] = is_dup_sorted

  return np.where(is_dup, 1, 0)