"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "cafa6a9e29f3e99c67a1028f8ca779d439bc0689"
Commit cb9db101 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Python 2 must DIE

parent 05c08352
...@@ -58,7 +58,7 @@ class RobertaEmbeddings(BertEmbeddings): ...@@ -58,7 +58,7 @@ class RobertaEmbeddings(BertEmbeddings):
# cf. fairseq's `utils.make_positions` # cf. fairseq's `utils.make_positions`
position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=input_ids.device) position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
return super().forward(input_ids, token_type_ids=token_type_ids, position_ids=position_ids) return super(RobertaEmbeddings, self).forward(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
class RobertaConfig(BertConfig): class RobertaConfig(BertConfig):
...@@ -109,8 +109,8 @@ class RobertaForMaskedLM(BertPreTrainedModel): ...@@ -109,8 +109,8 @@ class RobertaForMaskedLM(BertPreTrainedModel):
class RobertaLMHead(nn.Module): class RobertaLMHead(nn.Module):
"""Roberta Head for masked language modeling.""" """Roberta Head for masked language modeling."""
def __init__(self, config: BertConfig): def __init__(self, config):
super().__init__() super(RobertaLMHead, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
......
...@@ -18,6 +18,7 @@ from __future__ import (absolute_import, division, print_function, ...@@ -18,6 +18,7 @@ from __future__ import (absolute_import, division, print_function,
import os import os
import unittest import unittest
import pytest import pytest
import six
from pytorch_transformers.tokenization_roberta import RobertaTokenizer from pytorch_transformers.tokenization_roberta import RobertaTokenizer
...@@ -31,10 +32,11 @@ class RobertaTokenizationTest(unittest.TestCase): ...@@ -31,10 +32,11 @@ class RobertaTokenizationTest(unittest.TestCase):
tokenizer.encode('Hello world!'), tokenizer.encode('Hello world!'),
[0, 31414, 232, 328, 2] [0, 31414, 232, 328, 2]
) )
self.assertListEqual( if six.PY3:
tokenizer.encode('Hello world! cécé herlolip'), self.assertListEqual(
[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2] tokenizer.encode('Hello world! cécé herlolip'),
) [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
)
......
...@@ -19,6 +19,8 @@ from __future__ import (absolute_import, division, print_function, ...@@ -19,6 +19,8 @@ from __future__ import (absolute_import, division, print_function,
import json import json
import logging import logging
import re import re
from io import open
import six
from .tokenization_utils import PreTrainedTokenizer from .tokenization_utils import PreTrainedTokenizer
from .tokenization_gpt2 import GPT2Tokenizer from .tokenization_gpt2 import GPT2Tokenizer
...@@ -125,7 +127,7 @@ class Dictionary(object): ...@@ -125,7 +127,7 @@ class Dictionary(object):
Loads a pre-existing dictionary from a text file and adds its symbols Loads a pre-existing dictionary from a text file and adds its symbols
to this instance. to this instance.
""" """
if isinstance(f, str): if isinstance(f, six.string_types):
try: try:
if not ignore_utf_errors: if not ignore_utf_errors:
with open(f, 'r', encoding='utf-8') as fd: with open(f, 'r', encoding='utf-8') as fd:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment