chenpangpang/transformers, commit 366a3b02
Authored May 08, 2019 by thomwolf
Parent: 0efc4ab6

clean up in tokenization
Showing 3 changed files, with 13 additions and 4 deletions:

- pytorch_pretrained_bert/modeling_gpt2.py (+4, -2)
- pytorch_pretrained_bert/tokenization_gpt2.py (+8, -1)
- pytorch_pretrained_bert/tokenization_openai.py (+1, -1)

The commit registers the new gpt2-medium checkpoint in the GPT-2 pretrained archive maps, adds a clean_up_tokenization_spaces post-processing step to GPT2Tokenizer.decode, and removes a duplicated replace call from the equivalent cleanup in OpenAIGPTTokenizer.
pytorch_pretrained_bert/modeling_gpt2.py

```diff
@@ -39,8 +39,10 @@ from .modeling import BertLayerNorm as LayerNorm
 
 logger = logging.getLogger(__name__)
 
-PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
-PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}
+PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
+                                "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
+PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
+                                "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
 
 def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
     """ Load tf checkpoints in a pytorch model
```
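With both archive maps extended, the existing from_pretrained machinery should pick up the new shortcut name unchanged. A minimal usage sketch, assuming the pytorch_pretrained_bert API of this era (illustrative, not part of the commit):

```python
# Illustrative sketch, not part of the commit: "gpt2-medium" now resolves
# through PRETRAINED_MODEL_ARCHIVE_MAP and PRETRAINED_CONFIG_ARCHIVE_MAP
# to the S3 weights and config registered above.
from pytorch_pretrained_bert import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2-medium")
model.eval()  # inference mode: disables dropout
```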
pytorch_pretrained_bert/tokenization_gpt2.py

```diff
@@ -37,9 +37,11 @@ logger = logging.getLogger(__name__)
 
 PRETRAINED_VOCAB_ARCHIVE_MAP = {
     'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
+    'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
 }
 PRETRAINED_MERGES_ARCHIVE_MAP = {
     'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
+    'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
 }
 PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
     'gpt2': 1024,
```
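The same shortcut works for the tokenizer files once the vocab and merges URLs are registered. A hedged sketch (illustrative, not part of the diff):

```python
# Illustrative sketch, not part of the commit: the tokenizer shortcut
# resolves through PRETRAINED_VOCAB_ARCHIVE_MAP and PRETRAINED_MERGES_ARCHIVE_MAP.
from pytorch_pretrained_bert import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
ids = tokenizer.encode("Hello world")  # byte-level BPE token ids
```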
```diff
@@ -263,9 +265,14 @@ class GPT2Tokenizer(object):
     def encode(self, text):
         return self.convert_tokens_to_ids(self.tokenize(text))
 
-    def decode(self, tokens, skip_special_tokens=False):
+    def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True):
         text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+        if clean_up_tokenization_spaces:
+            text = text.replace('<unk>', '')
+            text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
+                ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
+                ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
         return text
 
     def save_vocabulary(self, vocab_path):
```
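The new clean_up_tokenization_spaces flag (default True) post-processes the decoded string: it strips '<unk>' and reattaches punctuation and contractions that tokenization may leave space-separated. Note that the " do not" to " don't" rule rewrites text even when the raw byte-level decode was already exact. A hedged round-trip sketch (illustrative, not part of the diff):

```python
# Illustrative sketch, not part of the commit: exercising the new
# clean_up_tokenization_spaces argument of GPT2Tokenizer.decode.
from pytorch_pretrained_bert import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
ids = tokenizer.encode("I do not like spam.")

# Raw byte-level BPE decoding round-trips the input exactly.
print(tokenizer.decode(ids, clean_up_tokenization_spaces=False))
# -> I do not like spam.

# The cleanup pass also applies the " do not" -> " don't" rewrite.
print(tokenizer.decode(ids))
# -> I don't like spam.
```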
pytorch_pretrained_bert/tokenization_openai.py

```diff
@@ -272,7 +272,7 @@ class OpenAIGPTTokenizer(object):
         out_string = ''.join(tokens).replace('</w>', ' ').strip()
         if clean_up_tokenization_spaces:
             out_string = out_string.replace('<unk>', '')
-            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ','
+            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                 ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                 ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
         return out_string
```
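The one-line change above drops a duplicated .replace(' ,', ',') from the end of the first line of the chain. The second call was a no-op, since str.replace already substitutes every occurrence in a single pass; a quick check (plain Python, illustrative only):

```python
# The removed duplicate was redundant: str.replace is global per call.
s = "hello , world , again"
assert s.replace(' ,', ',') == s.replace(' ,', ',').replace(' ,', ',') == "hello, world, again"
```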