"vscode:/vscode.git/clone" did not exist on "f53f4198c36d0a943de598ad91a20baa9481c5c5"
Unverified Commit eef72ed6 authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

remove unused imports and lint (#4475)

* remove unused imports and lint

* fix schedule.py

* address PR comments
parent 2310bc34
......@@ -19,12 +19,9 @@ from __future__ import division
from __future__ import print_function
import math
import time
import tensorflow as tf
from official.transformer.utils import dataset
_TRAIN, _EVAL = tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL
......@@ -123,7 +120,6 @@ class Manager(object):
Args:
num_epochs: An integer of the number of epochs to convert to steps.
batch_size: The mini-batch size used.
mode: The estimator ModeKey of the computation
Returns:
......
......@@ -14,14 +14,7 @@
# ==============================================================================
"""Functions specific to running TensorFlow on TPUs."""
import time
import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
# "local" is a magic word in the TPU cluster resolver; it informs the resolver
......@@ -84,7 +77,7 @@ def construct_scalar_host_call(metric_dict, model_dir, prefix=""):
return host_call_fn, [global_step_tensor] + other_tensors
def embedding_matmul(embedding_table, values, mask, name='embedding_matmul'):
def embedding_matmul(embedding_table, values, mask, name="embedding_matmul"):
"""Performs embedding lookup via a matmul.
The matrix to be multiplied by the embedding table Tensor is constructed
......@@ -104,21 +97,19 @@ def embedding_matmul(embedding_table, values, mask, name='embedding_matmul'):
Rank 3 tensor of embedding vectors.
"""
with ops.name_scope(name):
n_embeddings, embedding_dim = embedding_table.get_shape().as_list()
with tf.name_scope(name):
n_embeddings = embedding_table.get_shape().as_list()[0]
batch_size, padded_size = values.shape.as_list()
emb_idcs = array_ops.tile(
array_ops.reshape(values, (batch_size, padded_size, 1)), (1, 1,
n_embeddings))
emb_weights = array_ops.tile(
array_ops.reshape(mask, (batch_size, padded_size, 1)),
(1, 1, n_embeddings))
col_idcs = array_ops.tile(
array_ops.reshape(math_ops.range(n_embeddings), (1, 1, n_embeddings)),
emb_idcs = tf.tile(
tf.reshape(values, (batch_size, padded_size, 1)), (1, 1, n_embeddings))
emb_weights = tf.tile(
tf.reshape(mask, (batch_size, padded_size, 1)), (1, 1, n_embeddings))
col_idcs = tf.tile(
tf.reshape(tf.range(n_embeddings), (1, 1, n_embeddings)),
(batch_size, padded_size, 1))
one_hot = array_ops.where(
math_ops.equal(emb_idcs, col_idcs), emb_weights,
array_ops.zeros((batch_size, padded_size, n_embeddings)))
one_hot = tf.where(
tf.equal(emb_idcs, col_idcs), emb_weights,
tf.zeros((batch_size, padded_size, n_embeddings)))
return math_ops.tensordot(one_hot, embedding_table, 1)
return tf.tensordot(one_hot, embedding_table, 1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment