Commit a15ff20f authored by Ardalan's avatar Ardalan Committed by Francisco Massa
Browse files

add tar.xz archive handler (#1361)

* add tar.xz archive handler

* update unittest for tar.xz archive

* remove .tar.xz unittest

* add separate .tar.xz unittest

* update PY2 compatibility
parent 1909495a
......@@ -100,6 +100,23 @@ class Tester(unittest.TestCase):
data = nf.read()
self.assertEqual(data, 'this is the content')
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
@unittest.skipIf(sys.version_info < (3,), "Extracting .tar.xz files is not supported under Python 2.x")
def test_extract_tar_xz(self):
for ext, mode in zip(['.tar.xz'], ['w:xz']):
with get_tmp_dir() as temp_dir:
with tempfile.NamedTemporaryFile() as bf:
bf.write("this is the content".encode())
bf.seek(0)
with tempfile.NamedTemporaryFile(suffix=ext) as f:
with tarfile.open(f.name, mode=mode) as zf:
zf.add(bf.name, arcname='file.tst')
utils.extract_archive(f.name, temp_dir)
self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst')))
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
data = nf.read()
self.assertEqual(data, 'this is the content')
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_extract_gzip(self):
with get_tmp_dir() as temp_dir:
......
......@@ -8,6 +8,7 @@ import zipfile
import torch
from torch.utils.model_zoo import tqdm
from torch._six import PY3
def gen_bar_updater():
......@@ -197,6 +198,10 @@ def _save_response_content(response, destination, chunk_size=32768):
pbar.close()
def _is_tarxz(filename):
return filename.endswith(".tar.xz")
def _is_tar(filename):
return filename.endswith(".tar")
......@@ -223,6 +228,10 @@ def extract_archive(from_path, to_path=None, remove_finished=False):
elif _is_targz(from_path):
with tarfile.open(from_path, 'r:gz') as tar:
tar.extractall(path=to_path)
elif _is_tarxz(from_path) and PY3:
# .tar.xz archive only supported in Python 3.x
with tarfile.open(from_path, 'r:xz') as tar:
tar.extractall(path=to_path)
elif _is_gzip(from_path):
to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment