Commit 37df862e authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Add missing dependencies to hubconf

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/799

Differential Revision: D15773932

Pulled By: myleott

fbshipit-source-id: 650c0621bedb3b7ecebc0654d8e10d7692c50994
parent 5bdee18e
...@@ -37,10 +37,9 @@ class GPT2BPE(object): ...@@ -37,10 +37,9 @@ class GPT2BPE(object):
"""Byte pair encoding utilities from GPT-2""" """Byte pair encoding utilities from GPT-2"""
import os
import json
import regex as re
from functools import lru_cache from functools import lru_cache
import json
import os
@lru_cache() @lru_cache()
...@@ -77,6 +76,7 @@ def get_pairs(word): ...@@ -77,6 +76,7 @@ def get_pairs(word):
return pairs return pairs
class Encoder: class Encoder:
def __init__(self, encoder, bpe_merges, errors='replace'): def __init__(self, encoder, bpe_merges, errors='replace'):
self.encoder = encoder self.encoder = encoder
self.decoder = {v:k for k,v in self.encoder.items()} self.decoder = {v:k for k,v in self.encoder.items()}
...@@ -86,8 +86,14 @@ class Encoder: ...@@ -86,8 +86,14 @@ class Encoder:
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {} self.cache = {}
try:
import regex as re
self.re = re
except ImportError:
raise ImportError('Please install regex with: pip install regex')
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") self.pat = self.re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
def bpe(self, token): def bpe(self, token):
if token in self.cache: if token in self.cache:
...@@ -132,7 +138,7 @@ class Encoder: ...@@ -132,7 +138,7 @@ class Encoder:
def encode(self, text):
    """Turn *text* into its list of BPE token ids.

    Splits the text with the precompiled GPT-2 pattern, maps each raw
    token to its byte-level unicode representation, runs BPE merges on
    it, and looks every resulting piece up in the encoder vocabulary.
    """
    ids = []
    for raw_token in self.re.findall(self.pat, text):
        # byte-level encoding: every UTF-8 byte maps to a printable char
        byte_repr = ''.join(self.byte_encoder[b] for b in raw_token.encode('utf-8'))
        # BPE merge result is a space-separated string of sub-tokens
        for piece in self.bpe(byte_repr).split(' '):
            ids.append(self.encoder[piece])
    return ids
......
...@@ -8,7 +8,13 @@ ...@@ -8,7 +8,13 @@
from fairseq.models import MODEL_REGISTRY from fairseq.models import MODEL_REGISTRY
# Packages torch.hub must install before loading any fairseq model:
# the BPE/tokenization helpers (regex, sacremoses, sentencepiece,
# subword_nmt) plus torch itself.  Kept in alphabetical order.
dependencies = ['regex', 'sacremoses', 'sentencepiece', 'subword_nmt', 'torch']
for model, cls in MODEL_REGISTRY.items(): for model, cls in MODEL_REGISTRY.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment