"docs/zh_cn/vscode:/vscode.git/clone" did not exist on "5f58b9101f959a5794860b2737ab83b522d25915"
Unverified commit eef72ed6, authored by Taylor Robie, committed by GitHub

remove unused imports and lint (#4475)

* remove unused imports and lint

* fix schedule.py

* address PR comments
parent 2310bc34
@@ -19,12 +19,9 @@ from __future__ import division
 from __future__ import print_function
 import math
-import time
 import tensorflow as tf
-from official.transformer.utils import dataset
 _TRAIN, _EVAL = tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL
@@ -123,7 +120,6 @@ class Manager(object):
     Args:
       num_epochs: An integer of the number of epochs to convert to steps.
-      batch_size: The mini-batch size used.
       mode: The estimator ModeKey of the computation
     Returns:
...
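
Aside, not part of the commit: the docstring hunk above documents a helper that converts a number of epochs into a number of training steps. The underlying arithmetic is simply steps_per_epoch = examples_per_epoch // batch_size; a minimal sketch follows, with assumed names that are not the Manager's actual API.

# Illustrative sketch only; the function and argument names are assumptions.
def epochs_to_steps(num_epochs, examples_per_epoch, batch_size):
  """Converts epochs to steps; one step consumes one mini-batch of examples."""
  steps_per_epoch = examples_per_epoch // batch_size
  return num_epochs * steps_per_epoch

# e.g. 3 epochs over 1000 examples at batch size 50 -> 60 training steps.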
@@ -14,14 +14,7 @@
 # ==============================================================================
 """Functions specific to running TensorFlow on TPUs."""
-import time
 import tensorflow as tf
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import embedding_ops
-from tensorflow.python.ops import math_ops
 # "local" is a magic word in the TPU cluster resolver; it informs the resolver
@@ -84,7 +77,7 @@ def construct_scalar_host_call(metric_dict, model_dir, prefix=""):
   return host_call_fn, [global_step_tensor] + other_tensors
-def embedding_matmul(embedding_table, values, mask, name='embedding_matmul'):
+def embedding_matmul(embedding_table, values, mask, name="embedding_matmul"):
   """Performs embedding lookup via a matmul.
   The matrix to be multiplied by the embedding table Tensor is constructed
@@ -104,21 +97,19 @@ def embedding_matmul(embedding_table, values, mask, name="embedding_matmul"):
     Rank 3 tensor of embedding vectors.
   """
-  with ops.name_scope(name):
-    n_embeddings, embedding_dim = embedding_table.get_shape().as_list()
+  with tf.name_scope(name):
+    n_embeddings = embedding_table.get_shape().as_list()[0]
     batch_size, padded_size = values.shape.as_list()
-    emb_idcs = array_ops.tile(
-        array_ops.reshape(values, (batch_size, padded_size, 1)), (1, 1,
-                                                                  n_embeddings))
-    emb_weights = array_ops.tile(
-        array_ops.reshape(mask, (batch_size, padded_size, 1)),
-        (1, 1, n_embeddings))
-    col_idcs = array_ops.tile(
-        array_ops.reshape(math_ops.range(n_embeddings), (1, 1, n_embeddings)),
+    emb_idcs = tf.tile(
+        tf.reshape(values, (batch_size, padded_size, 1)), (1, 1, n_embeddings))
+    emb_weights = tf.tile(
+        tf.reshape(mask, (batch_size, padded_size, 1)), (1, 1, n_embeddings))
+    col_idcs = tf.tile(
+        tf.reshape(tf.range(n_embeddings), (1, 1, n_embeddings)),
         (batch_size, padded_size, 1))
-    one_hot = array_ops.where(
-        math_ops.equal(emb_idcs, col_idcs), emb_weights,
-        array_ops.zeros((batch_size, padded_size, n_embeddings)))
-    return math_ops.tensordot(one_hot, embedding_table, 1)
+    one_hot = tf.where(
+        tf.equal(emb_idcs, col_idcs), emb_weights,
+        tf.zeros((batch_size, padded_size, n_embeddings)))
+    return tf.tensordot(one_hot, embedding_table, 1)
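
Aside, not part of the commit: the rewritten embedding_matmul builds a mask-weighted one-hot tensor of shape (batch_size, padded_size, n_embeddings) and contracts it with the embedding table, which is equivalent to an embedding lookup in which padded positions come back as zero rows. A NumPy rendering of the same contraction, with made-up shapes, for reference:

# Illustrative sketch only; mirrors the one-hot-times-table contraction above.
import numpy as np

embedding_table = np.arange(12, dtype=np.float32).reshape(4, 3)  # 4 ids, dim 3
values = np.array([[1, 3, 0]])                                   # batch=1, padded_size=3
mask = np.array([[1.0, 1.0, 0.0]])                               # last position is padding

one_hot = (values[..., None] == np.arange(4)) * mask[..., None]  # (1, 3, 4)
result = np.tensordot(one_hot, embedding_table, 1)               # (1, 3, 3)
# Rows for real tokens equal embedding_table[values]; the padded row is all zeros.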