Unverified commit f9fd7fd7, authored by Rhett Ying and committed by GitHub
Browse files

[BugFix] extract gz into target dir (#3389)

parent e234fcfa
......@@ -219,7 +219,8 @@ def extract_archive(file, target_dir, overwrite=False):
import gzip
import shutil
with gzip.open(file, 'rb') as f_in:
with open(file[:-3], 'wb') as f_out:
target_file = os.path.join(target_dir, os.path.basename(file)[:-3])
with open(target_file, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
elif file.endswith('.zip'):
import zipfile
......
......@@ -2,6 +2,9 @@ import dgl.data as data
import unittest
import backend as F
import numpy as np
import gzip
import tempfile
import os
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
......@@ -139,6 +142,20 @@ def test_reddit():
assert np.array_equal(dst, np.sort(dst))
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_extract_archive():
    """Check that ``extract_archive`` unpacks a ``.gz`` file into the target dir.

    A small gzip archive is written into one temporary directory and then
    extracted into a second, separate temporary directory. The test passes
    only if the decompressed file (the archive's base name without the
    ``.gz`` suffix) appears inside the destination directory and holds
    exactly the bytes that were compressed.
    """
    # gzip
    gz_file = 'gz_archive'
    content = b"test extract archive gzip"
    with tempfile.TemporaryDirectory() as src_dir, \
            tempfile.TemporaryDirectory() as dst_dir:
        # Source and destination are distinct directories so that an
        # extraction landing next to the archive is detected as a failure.
        gz_path = os.path.join(src_dir, gz_file + '.gz')
        with gzip.open(gz_path, 'wb') as f:
            f.write(content)

        data.utils.extract_archive(gz_path, dst_dir, overwrite=True)

        extracted_path = os.path.join(dst_dir, gz_file)
        assert os.path.exists(extracted_path)
        # Existence alone would accept an empty or truncated output file;
        # also verify that the payload round-trips unchanged.
        with open(extracted_path, 'rb') as f:
            assert f.read() == content
if __name__ == '__main__':
test_minigc()
test_gin()
......@@ -146,3 +163,4 @@ if __name__ == '__main__':
test_tudataset_regression()
test_fraud()
test_fakenews()
test_extract_archive()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment