Unverified Commit 491c19fe authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Bugfix] Fix the call of tar.extractall() method (#5139)

parent 46455328
...@@ -257,19 +257,22 @@ def extract_archive(file, target_dir, overwrite=False): ...@@ -257,19 +257,22 @@ def extract_archive(file, target_dir, overwrite=False):
import tarfile import tarfile
with tarfile.open(file, "r") as archive: with tarfile.open(file, "r") as archive:
def is_within_directory(directory, target): def is_within_directory(directory, target):
abs_directory = os.path.abspath(directory) abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target) abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target]) prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory return prefix == abs_directory
def safe_extract(tar, path=".", members=None, *, numeric_owner=False): def safe_extract(
tar, path=".", members=None, *, numeric_owner=False
):
for member in tar.getmembers(): for member in tar.getmembers():
member_path = os.path.join(path, member.name) member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path): if not is_within_directory(path, member_path):
raise Exception("Attempted Path Traversal in Tar File") raise Exception("Attempted Path Traversal in Tar File")
tar.extractall(path, members, numeric_owner) tar.extractall(path, members, numeric_owner=numeric_owner)
safe_extract(archive, path=target_dir) safe_extract(archive, path=target_dir)
elif file.endswith(".gz"): elif file.endswith(".gz"):
import gzip import gzip
......
import gzip import gzip
import io
import os import os
import tarfile
import tempfile import tempfile
import unittest import unittest
import backend as F import backend as F
import dgl
import dgl.data as data
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import pytest import pytest
import yaml import yaml
import dgl
import dgl.data as data
from dgl import DGLError from dgl import DGLError
...@@ -379,6 +381,20 @@ def test_extract_archive(): ...@@ -379,6 +381,20 @@ def test_extract_archive():
data.utils.extract_archive(gz_path, dst_dir, overwrite=True) data.utils.extract_archive(gz_path, dst_dir, overwrite=True)
assert os.path.exists(os.path.join(dst_dir, gz_file)) assert os.path.exists(os.path.join(dst_dir, gz_file))
# tar
with tempfile.TemporaryDirectory() as src_dir:
tar_file = "tar_archive"
tar_path = os.path.join(src_dir, tar_file + ".tar")
# default encode to utf8
content = "test extract archive tar\n".encode()
info = tarfile.TarInfo(name="tar_archive")
info.size = len(content)
with tarfile.open(tar_path, "w") as f:
f.addfile(info, io.BytesIO(content))
with tempfile.TemporaryDirectory() as dst_dir:
data.utils.extract_archive(tar_path, dst_dir, overwrite=True)
assert os.path.exists(os.path.join(dst_dir, tar_file))
def _test_construct_graphs_node_ids(): def _test_construct_graphs_node_ids():
from dgl.data.csv_dataset_base import ( from dgl.data.csv_dataset_base import (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment