Commit 5e5e6f6e authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 443735362
parent 14c32065
...@@ -77,8 +77,8 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler): ...@@ -77,8 +77,8 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
tf.zeros(input_length, tf.int32)) tf.zeros(input_length, tf.int32))
num_sampled_pos = tf.reduce_sum( num_sampled_pos = tf.reduce_sum(
input_tensor=tf.cast(valid_positive_index, tf.int32)) input_tensor=tf.cast(valid_positive_index, tf.int32))
max_num_positive_samples = tf.constant( max_num_positive_samples = tf.cast(
int(sample_size * self._positive_fraction), tf.int32) tf.cast(sample_size, tf.float32) * self._positive_fraction, tf.int32)
num_positive_samples = tf.minimum(max_num_positive_samples, num_sampled_pos) num_positive_samples = tf.minimum(max_num_positive_samples, num_sampled_pos)
num_negative_samples = tf.constant(sample_size, num_negative_samples = tf.constant(sample_size,
tf.int32) - num_positive_samples tf.int32) - num_positive_samples
...@@ -219,7 +219,7 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler): ...@@ -219,7 +219,7 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
indicator: boolean tensor of shape [N] whose True entries can be sampled. indicator: boolean tensor of shape [N] whose True entries can be sampled.
batch_size: desired batch size. If None, keeps all positive samples and batch_size: desired batch size. If None, keeps all positive samples and
randomly selects negative samples so that the positive sample fraction randomly selects negative samples so that the positive sample fraction
matches self._positive_fraction. It cannot be None is is_static is True. matches self._positive_fraction. It cannot be None if is_static is True.
labels: boolean tensor of shape [N] denoting positive(=True) and negative labels: boolean tensor of shape [N] denoting positive(=True) and negative
(=False) examples. (=False) examples.
scope: name scope. scope: name scope.
...@@ -259,7 +259,9 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler): ...@@ -259,7 +259,9 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
max_num_pos = tf.reduce_sum( max_num_pos = tf.reduce_sum(
input_tensor=tf.cast(positive_idx, dtype=tf.int32)) input_tensor=tf.cast(positive_idx, dtype=tf.int32))
else: else:
max_num_pos = int(self._positive_fraction * batch_size) max_num_pos = tf.cast(
self._positive_fraction * tf.cast(batch_size, tf.float32),
tf.int32)
sampled_pos_idx = self.subsample_indicator(positive_idx, max_num_pos) sampled_pos_idx = self.subsample_indicator(positive_idx, max_num_pos)
num_sampled_pos = tf.reduce_sum( num_sampled_pos = tf.reduce_sum(
input_tensor=tf.cast(sampled_pos_idx, tf.int32)) input_tensor=tf.cast(sampled_pos_idx, tf.int32))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment