Commit 5e5e6f6e authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 443735362
parent 14c32065
......@@ -77,8 +77,8 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
tf.zeros(input_length, tf.int32))
num_sampled_pos = tf.reduce_sum(
input_tensor=tf.cast(valid_positive_index, tf.int32))
max_num_positive_samples = tf.constant(
int(sample_size * self._positive_fraction), tf.int32)
max_num_positive_samples = tf.cast(
tf.cast(sample_size, tf.float32) * self._positive_fraction, tf.int32)
num_positive_samples = tf.minimum(max_num_positive_samples, num_sampled_pos)
num_negative_samples = tf.constant(sample_size,
tf.int32) - num_positive_samples
......@@ -219,7 +219,7 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
indicator: boolean tensor of shape [N] whose True entries can be sampled.
batch_size: desired batch size. If None, keeps all positive samples and
randomly selects negative samples so that the positive sample fraction
matches self._positive_fraction. It cannot be None is is_static is True.
matches self._positive_fraction. It cannot be None if is_static is True.
labels: boolean tensor of shape [N] denoting positive(=True) and negative
(=False) examples.
scope: name scope.
......@@ -259,7 +259,9 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
max_num_pos = tf.reduce_sum(
input_tensor=tf.cast(positive_idx, dtype=tf.int32))
else:
max_num_pos = int(self._positive_fraction * batch_size)
max_num_pos = tf.cast(
self._positive_fraction * tf.cast(batch_size, tf.float32),
tf.int32)
sampled_pos_idx = self.subsample_indicator(positive_idx, max_num_pos)
num_sampled_pos = tf.reduce_sum(
input_tensor=tf.cast(sampled_pos_idx, tf.int32))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment