Commit 9761aa48 authored by thomwolf's avatar thomwolf
Browse files

add to_json_file method to configuration classes

parent b17963d8
...@@ -220,6 +220,11 @@ class BertConfig(object): ...@@ -220,6 +220,11 @@ class BertConfig(object):
"""Serializes this instance to a JSON string.""" """Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path):
""" Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer:
writer.write(self.to_json_string())
try: try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError: except ImportError:
......
...@@ -180,6 +180,11 @@ class GPT2Config(object): ...@@ -180,6 +180,11 @@ class GPT2Config(object):
"""Serializes this instance to a JSON string.""" """Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path):
""" Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer:
writer.write(self.to_json_string())
class Conv1D(nn.Module): class Conv1D(nn.Module):
def __init__(self, nf, nx): def __init__(self, nf, nx):
......
...@@ -225,6 +225,11 @@ class OpenAIGPTConfig(object): ...@@ -225,6 +225,11 @@ class OpenAIGPTConfig(object):
"""Serializes this instance to a JSON string.""" """Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path):
""" Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer:
writer.write(self.to_json_string())
class Conv1D(nn.Module): class Conv1D(nn.Module):
def __init__(self, nf, rf, nx): def __init__(self, nf, rf, nx):
......
...@@ -316,6 +316,11 @@ class TransfoXLConfig(object): ...@@ -316,6 +316,11 @@ class TransfoXLConfig(object):
"""Serializes this instance to a JSON string.""" """Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path):
""" Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer:
writer.write(self.to_json_string())
class PositionalEmbedding(nn.Module): class PositionalEmbedding(nn.Module):
def __init__(self, demb): def __init__(self, demb):
......
...@@ -16,6 +16,7 @@ from __future__ import absolute_import ...@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import unittest import unittest
import json import json
import random import random
...@@ -176,6 +177,14 @@ class GPT2ModelTest(unittest.TestCase): ...@@ -176,6 +177,14 @@ class GPT2ModelTest(unittest.TestCase):
self.assertEqual(obj["vocab_size"], 99) self.assertEqual(obj["vocab_size"], 99)
self.assertEqual(obj["n_embd"], 37) self.assertEqual(obj["n_embd"], 37)
def test_config_to_json_file(self):
config_first = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37)
json_file_path = "/tmp/config.json"
config_first.to_json_file(json_file_path)
config_second = GPT2Config.from_json_file(json_file_path)
os.remove(json_file_path)
self.assertEqual(config_second.to_dict(), config_first.to_dict())
def run_tester(self, tester): def run_tester(self, tester):
config_and_inputs = tester.prepare_config_and_inputs() config_and_inputs = tester.prepare_config_and_inputs()
output_result = tester.create_gpt2_model(*config_and_inputs) output_result = tester.create_gpt2_model(*config_and_inputs)
......
...@@ -16,6 +16,7 @@ from __future__ import absolute_import ...@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import unittest import unittest
import json import json
import random import random
...@@ -188,6 +189,14 @@ class OpenAIGPTModelTest(unittest.TestCase): ...@@ -188,6 +189,14 @@ class OpenAIGPTModelTest(unittest.TestCase):
self.assertEqual(obj["vocab_size"], 99) self.assertEqual(obj["vocab_size"], 99)
self.assertEqual(obj["n_embd"], 37) self.assertEqual(obj["n_embd"], 37)
def test_config_to_json_file(self):
config_first = OpenAIGPTConfig(vocab_size_or_config_json_file=99, n_embd=37)
json_file_path = "/tmp/config.json"
config_first.to_json_file(json_file_path)
config_second = OpenAIGPTConfig.from_json_file(json_file_path)
os.remove(json_file_path)
self.assertEqual(config_second.to_dict(), config_first.to_dict())
def run_tester(self, tester): def run_tester(self, tester):
config_and_inputs = tester.prepare_config_and_inputs() config_and_inputs = tester.prepare_config_and_inputs()
output_result = tester.create_openai_model(*config_and_inputs) output_result = tester.create_openai_model(*config_and_inputs)
......
...@@ -16,6 +16,7 @@ from __future__ import absolute_import ...@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import unittest import unittest
import json import json
import random import random
...@@ -251,6 +252,14 @@ class BertModelTest(unittest.TestCase): ...@@ -251,6 +252,14 @@ class BertModelTest(unittest.TestCase):
self.assertEqual(obj["vocab_size"], 99) self.assertEqual(obj["vocab_size"], 99)
self.assertEqual(obj["hidden_size"], 37) self.assertEqual(obj["hidden_size"], 37)
def test_config_to_json_file(self):
config_first = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
json_file_path = "/tmp/config.json"
config_first.to_json_file(json_file_path)
config_second = BertConfig.from_json_file(json_file_path)
os.remove(json_file_path)
self.assertEqual(config_second.to_dict(), config_first.to_dict())
def run_tester(self, tester): def run_tester(self, tester):
config_and_inputs = tester.prepare_config_and_inputs() config_and_inputs = tester.prepare_config_and_inputs()
output_result = tester.create_bert_model(*config_and_inputs) output_result = tester.create_bert_model(*config_and_inputs)
......
...@@ -16,6 +16,7 @@ from __future__ import absolute_import ...@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import unittest import unittest
import json import json
import random import random
...@@ -186,6 +187,14 @@ class TransfoXLModelTest(unittest.TestCase): ...@@ -186,6 +187,14 @@ class TransfoXLModelTest(unittest.TestCase):
self.assertEqual(obj["n_token"], 96) self.assertEqual(obj["n_token"], 96)
self.assertEqual(obj["d_embed"], 37) self.assertEqual(obj["d_embed"], 37)
def test_config_to_json_file(self):
config_first = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
json_file_path = "/tmp/config.json"
config_first.to_json_file(json_file_path)
config_second = TransfoXLConfig.from_json_file(json_file_path)
os.remove(json_file_path)
self.assertEqual(config_second.to_dict(), config_first.to_dict())
def run_tester(self, tester): def run_tester(self, tester):
config_and_inputs = tester.prepare_config_and_inputs() config_and_inputs = tester.prepare_config_and_inputs()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment