Commit 5a898973 authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 361957289
parent 7e3d5270
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Defines the Transformer model in TF 2.0. """Defines the Transformer model in TF 2.0.
Model paper: https://arxiv.org/pdf/1706.03762.pdf Model paper: https://arxiv.org/pdf/1706.03762.pdf
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Forward pass test for Transformer model refactoring.""" """Forward pass test for Transformer model refactoring."""
import numpy as np import numpy as np
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Tests for layers in Transformer.""" """Tests for layers in Transformer."""
import tensorflow as tf import tensorflow as tf
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Train and evaluate the Transformer model. """Train and evaluate the Transformer model.
See README for description of setting the training schedule and evaluating the See README for description of setting the training schedule and evaluating the
...@@ -38,6 +38,7 @@ from official.nlp.transformer import translate ...@@ -38,6 +38,7 @@ from official.nlp.transformer import translate
from official.nlp.transformer.utils import tokenizer from official.nlp.transformer.utils import tokenizer
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
# pylint:disable=logging-format-interpolation
INF = int(1e9) INF = int(1e9)
BLEU_DIR = "bleu" BLEU_DIR = "bleu"
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Test Transformer model.""" """Test Transformer model."""
import os import os
...@@ -25,7 +25,6 @@ import tensorflow as tf ...@@ -25,7 +25,6 @@ import tensorflow as tf
from tensorflow.python.eager import context # pylint: disable=ungrouped-imports from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
from official.nlp.transformer import misc from official.nlp.transformer import misc
from official.nlp.transformer import transformer_main from official.nlp.transformer import transformer_main
from official.utils.misc import keras_utils
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
FIXED_TIMESTAMP = 'my_time_stamp' FIXED_TIMESTAMP = 'my_time_stamp'
...@@ -41,7 +40,7 @@ def _generate_file(filepath, lines): ...@@ -41,7 +40,7 @@ def _generate_file(filepath, lines):
class TransformerTaskTest(tf.test.TestCase): class TransformerTaskTest(tf.test.TestCase):
local_flags = None local_flags = None
def setUp(self): def setUp(self): # pylint: disable=g-missing-super-call
temp_dir = self.get_temp_dir() temp_dir = self.get_temp_dir()
if TransformerTaskTest.local_flags is None: if TransformerTaskTest.local_flags is None:
misc.define_transformer_flags() misc.define_transformer_flags()
...@@ -70,7 +69,7 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -70,7 +69,7 @@ class TransformerTaskTest(tf.test.TestCase):
self.orig_policy = ( self.orig_policy = (
tf.compat.v2.keras.mixed_precision.experimental.global_policy()) tf.compat.v2.keras.mixed_precision.experimental.global_policy())
def tearDown(self): def tearDown(self): # pylint: disable=g-missing-super-call
tf.compat.v2.keras.mixed_precision.experimental.set_policy(self.orig_policy) tf.compat.v2.keras.mixed_precision.experimental.set_policy(self.orig_policy)
def _assert_exists(self, filepath): def _assert_exists(self, filepath):
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Test Transformer model.""" """Test Transformer model."""
import tensorflow as tf import tensorflow as tf
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Translate text or files using trained transformer model.""" """Translate text or files using trained transformer model."""
# Import libraries # Import libraries
...@@ -146,7 +146,7 @@ def translate_file(model, ...@@ -146,7 +146,7 @@ def translate_file(model,
def text_as_per_replica(): def text_as_per_replica():
replica_context = tf.distribute.get_replica_context() replica_context = tf.distribute.get_replica_context()
replica_id = replica_context.replica_id_in_sync_group replica_id = replica_context.replica_id_in_sync_group
return replica_id, text[replica_id] return replica_id, text[replica_id] # pylint: disable=cell-var-from-loop
text = distribution_strategy.run(text_as_per_replica) text = distribution_strategy.run(text_as_per_replica)
outputs = distribution_strategy.experimental_local_results( outputs = distribution_strategy.experimental_local_results(
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the 'License'); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Functions for calculating loss, accuracy, and other model metrics. """Functions for calculating loss, accuracy, and other model metrics.
Metrics: Metrics:
...@@ -203,7 +203,7 @@ def bleu_score(logits, labels): ...@@ -203,7 +203,7 @@ def bleu_score(logits, labels):
bleu: int, approx bleu score bleu: int, approx bleu score
""" """
predictions = tf.cast(tf.argmax(logits, axis=-1), tf.int32) predictions = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
# TODO: Look into removing use of py_func # TODO: Look into removing use of py_func # pylint: disable=g-bad-todo
bleu = tf.py_func(compute_bleu, (labels, predictions), tf.float32) bleu = tf.py_func(compute_bleu, (labels, predictions), tf.float32)
return bleu, tf.constant(1.0) return bleu, tf.constant(1.0)
...@@ -308,7 +308,7 @@ def rouge_2_fscore(logits, labels): ...@@ -308,7 +308,7 @@ def rouge_2_fscore(logits, labels):
rouge2_fscore: approx rouge-2 f1 score. rouge2_fscore: approx rouge-2 f1 score.
""" """
predictions = tf.cast(tf.argmax(logits, axis=-1), tf.int32) predictions = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
# TODO: Look into removing use of py_func # TODO: Look into removing use of py_func # pylint: disable=g-bad-todo
rouge_2_f_score = tf.py_func(rouge_n, (predictions, labels), tf.float32) rouge_2_f_score = tf.py_func(rouge_n, (predictions, labels), tf.float32)
return rouge_2_f_score, tf.constant(1.0) return rouge_2_f_score, tf.constant(1.0)
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Defines Subtokenizer class to encode and decode strings.""" """Defines Subtokenizer class to encode and decode strings."""
from __future__ import absolute_import from __future__ import absolute_import
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Test Subtokenizer and string helper methods.""" """Test Subtokenizer and string helper methods."""
import collections import collections
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Utilities for pre-processing classification data.""" """Utilities for pre-processing classification data."""
from absl import logging from absl import logging
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Common flags used in XLNet model.""" """Common flags used in XLNet model."""
from absl import flags from absl import flags
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Utilities used for data preparation.""" """Utilities used for data preparation."""
import collections import collections
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Functions and classes related to optimization (weight updates).""" """Functions and classes related to optimization (weight updates)."""
from absl import logging from absl import logging
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Script to pre-process classification data into tfrecords.""" """Script to pre-process classification data into tfrecords."""
import collections import collections
......
# -*- coding: utf-8 -*- # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,8 @@ ...@@ -12,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
"""Script to pre-process pre-training data into tfrecords.""" """Script to pre-process pre-training data into tfrecords."""
import json import json
...@@ -34,15 +34,15 @@ FLAGS = flags.FLAGS ...@@ -34,15 +34,15 @@ FLAGS = flags.FLAGS
special_symbols = { special_symbols = {
"<unk>" : 0, "<unk>": 0,
"<s>" : 1, "<s>": 1,
"</s>" : 2, "</s>": 2,
"<cls>" : 3, "<cls>": 3,
"<sep>" : 4, "<sep>": 4,
"<pad>" : 5, "<pad>": 5,
"<mask>" : 6, "<mask>": 6,
"<eod>" : 7, "<eod>": 7,
"<eop>" : 8, "<eop>": 8,
} }
VOCAB_SIZE = 32000 VOCAB_SIZE = 32000
...@@ -627,6 +627,7 @@ def _convert_example(example, use_bfloat16): ...@@ -627,6 +627,7 @@ def _convert_example(example, use_bfloat16):
def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts, def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
host_id, num_core_per_host, bsz_per_core): host_id, num_core_per_host, bsz_per_core):
"""Parses files to a dataset.""" """Parses files to a dataset."""
del num_batch
# list of file pathes # list of file pathes
num_files = len(file_names) num_files = len(file_names)
num_files_per_host = num_files // num_hosts num_files_per_host = num_files // num_hosts
...@@ -733,6 +734,8 @@ def get_dataset(params, num_hosts, num_core_per_host, split, file_names, ...@@ -733,6 +734,8 @@ def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
mask_beta, use_bfloat16=False, num_predict=None): mask_beta, use_bfloat16=False, num_predict=None):
"""Gets the dataset.""" """Gets the dataset."""
del mask_alpha
del mask_beta
bsz_per_core = params["batch_size"] bsz_per_core = params["batch_size"]
if num_hosts > 1: if num_hosts > 1:
host_id = params["context"].current_host host_id = params["context"].current_host
......
# coding=utf-8 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,8 @@ ...@@ -12,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
# coding=utf-8
"""Script to pre-process SQUAD data into tfrecords.""" """Script to pre-process SQUAD data into tfrecords."""
import os import os
......
# coding=utf-8 # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +11,8 @@ ...@@ -12,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
# coding=utf-8
"""Utilities for pre-processing.""" """Utilities for pre-processing."""
import unicodedata import unicodedata
...@@ -36,7 +36,7 @@ def printable_text(text): ...@@ -36,7 +36,7 @@ def printable_text(text):
elif six.PY2: elif six.PY2:
if isinstance(text, str): if isinstance(text, str):
return text return text
elif isinstance(text, unicode): elif isinstance(text, unicode): # pylint: disable=undefined-variable
return text.encode('utf-8') return text.encode('utf-8')
else: else:
raise ValueError('Unsupported string type: %s' % (type(text))) raise ValueError('Unsupported string type: %s' % (type(text)))
...@@ -81,7 +81,7 @@ def encode_pieces(sp_model, text, return_unicode=True, sample=False): ...@@ -81,7 +81,7 @@ def encode_pieces(sp_model, text, return_unicode=True, sample=False):
"""Encodes pieces.""" """Encodes pieces."""
# return_unicode is used only for py2 # return_unicode is used only for py2
if six.PY2 and isinstance(text, unicode): if six.PY2 and isinstance(text, unicode): # pylint: disable=undefined-variable
text = text.encode('utf-8') text = text.encode('utf-8')
if not sample: if not sample:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment