Commit 981a5c8c authored by thomwolf's avatar thomwolf
Browse files

updating models urls

parent 8ae1044f
......@@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
't5-base': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json",
't5-large': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json",
't5-3B': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3B-config.json",
't5-11B': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11B-config.json",
}
......
......@@ -121,7 +121,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
if compare_with_pt_model:
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
tf_inputs = tf.constant(inputs_list)
tf_inputs = tf_model.dummy_inputs
tfo = tf_model(tf_inputs, training=False) # build the network
pt_model = pt_model_class.from_pretrained(None,
......
......@@ -42,6 +42,10 @@ logger = logging.getLogger(__name__)
####################################################
T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-pytorch_model.bin",
't5-base': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-pytorch_model.bin",
't5-large': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-pytorch_model.bin",
't5-3B': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3B-pytorch_model.bin",
't5-11B': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11B-pytorch_model.bin",
}
####################################################
......
......@@ -25,13 +25,17 @@ import itertools
import tensorflow as tf
from .configuration_t5 import T5Config
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list, get_initializer, DUMMY_INPUTS
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-tf_model.h5",
't5-base': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-tf_model.h5",
't5-large': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-tf_model.h5",
't5-3B': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3B-tf_model.h5",
't5-11B': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11B-tf_model.h5",
}
####################################################
......
......@@ -41,7 +41,11 @@ VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
't5': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-base': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-large': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-3B': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-11B': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
}
}
......@@ -49,7 +53,11 @@ PRETRAINED_VOCAB_FILES_MAP = {
# Mapping from model shortcut names to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
't5': 512,
't5-small': 512,
't5-base': 512,
't5-large': 512,
't5-3B': 512,
't5-11B': 512,
}
class T5Tokenizer(PreTrainedTokenizer):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment