Unverified Commit 562f8640 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge branch 'master' into fix-xlnet-squad2.0

parents ca99a2d5 8618bf15
This diff is collapsed.
# progress bars in model download and training scripts
tqdm
# Accessing files from S3 directly.
boto3
# Used for downloading models over HTTP
requests
# For ROUGE
nltk
py-rouge
This diff is collapsed.
...@@ -21,7 +21,6 @@ from utils_summarization import ( ...@@ -21,7 +21,6 @@ from utils_summarization import (
compute_token_type_ids, compute_token_type_ids,
fit_to_block_size, fit_to_block_size,
build_mask, build_mask,
build_lm_labels,
process_story, process_story,
) )
...@@ -88,20 +87,6 @@ class SummarizationDataProcessingTest(unittest.TestCase): ...@@ -88,20 +87,6 @@ class SummarizationDataProcessingTest(unittest.TestCase):
expected_summary_lines = ["It was the best of times."] expected_summary_lines = ["It was the best of times."]
self.assertEqual(expected_summary_lines, summary_lines) self.assertEqual(expected_summary_lines, summary_lines)
def test_build_lm_labels_no_padding(self):
sequence = torch.tensor([1, 2, 3, 4])
expected = sequence
np.testing.assert_array_equal(
build_lm_labels(sequence, 0).numpy(), expected.numpy()
)
def test_build_lm_labels(self):
sequence = torch.tensor([1, 2, 3, 4, 0, 0, 0])
expected = torch.tensor([1, 2, 3, 4, -1, -1, -1])
np.testing.assert_array_equal(
build_lm_labels(sequence, 0).numpy(), expected.numpy()
)
def test_build_mask_no_padding(self): def test_build_mask_no_padding(self):
sequence = torch.tensor([1, 2, 3, 4]) sequence = torch.tensor([1, 2, 3, 4])
expected = torch.tensor([1, 1, 1, 1]) expected = torch.tensor([1, 1, 1, 1])
...@@ -125,7 +110,7 @@ class SummarizationDataProcessingTest(unittest.TestCase): ...@@ -125,7 +110,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
[[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]] [[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]]
) )
expected = torch.tensor( expected = torch.tensor(
[[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1], [0, 1, 1, 1, 0, 0]] [[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]]
) )
result = compute_token_type_ids(batch, separator) result = compute_token_type_ids(batch, separator)
......
...@@ -72,8 +72,7 @@ class ExamplesTests(unittest.TestCase): ...@@ -72,8 +72,7 @@ class ExamplesTests(unittest.TestCase):
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
testargs = ["run_squad.py", testargs = ["run_squad.py",
"--train_file=./examples/tests_samples/SQUAD/dev-v2.0-small.json", "--data_dir=./examples/tests_samples/SQUAD",
"--predict_file=./examples/tests_samples/SQUAD/dev-v2.0-small.json",
"--model_name=bert-base-uncased", "--model_name=bert-base-uncased",
"--output_dir=./examples/tests_samples/temp_dir", "--output_dir=./examples/tests_samples/temp_dir",
"--max_steps=10", "--max_steps=10",
......
This diff is collapsed.
This diff is collapsed.
...@@ -5,7 +5,7 @@ boto3 ...@@ -5,7 +5,7 @@ boto3
# Used for downloading models over HTTP # Used for downloading models over HTTP
requests requests
# For OpenAI GPT # For OpenAI GPT
regex regex != 2019.12.17
# For XLNet # For XLNet
sentencepiece sentencepiece
# For XLM # For XLM
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment