Commit c36d19db authored by mashun1's avatar mashun1
Browse files

liveportrait

parents
Pipeline #1402 canceled with stages
import numpy as np
from numpy.linalg import norm as l2norm
#from easydict import EasyDict
class Face(dict):
def __init__(self, d=None, **kwargs):
if d is None:
d = {}
if kwargs:
d.update(**kwargs)
for k, v in d.items():
setattr(self, k, v)
# Class attributes
#for k in self.__class__.__dict__.keys():
# if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'):
# setattr(self, k, getattr(self, k))
def __setattr__(self, name, value):
if isinstance(value, (list, tuple)):
value = [self.__class__(x)
if isinstance(x, dict) else x for x in value]
elif isinstance(value, dict) and not isinstance(value, self.__class__):
value = self.__class__(value)
super(Face, self).__setattr__(name, value)
super(Face, self).__setitem__(name, value)
__setitem__ = __setattr__
def __getattr__(self, name):
return None
@property
def embedding_norm(self):
if self.embedding is None:
return None
return l2norm(self.embedding)
@property
def normed_embedding(self):
if self.embedding is None:
return None
return self.embedding / self.embedding_norm
@property
def sex(self):
if self.gender is None:
return None
return 'M' if self.gender==1 else 'F'
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
from __future__ import division
import glob
import os.path as osp
import numpy as np
import onnxruntime
from numpy.linalg import norm
from ..model_zoo import model_zoo
from ..utils import ensure_available
from .common import Face
DEFAULT_MP_NAME = 'buffalo_l'
__all__ = ['FaceAnalysis']
class FaceAnalysis:
def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs):
onnxruntime.set_default_logger_severity(3)
self.models = {}
self.model_dir = ensure_available('models', name, root=root)
onnx_files = glob.glob(osp.join(self.model_dir, '*.onnx'))
onnx_files = sorted(onnx_files)
for onnx_file in onnx_files:
model = model_zoo.get_model(onnx_file, **kwargs)
if model is None:
print('model not recognized:', onnx_file)
elif allowed_modules is not None and model.taskname not in allowed_modules:
print('model ignore:', onnx_file, model.taskname)
del model
elif model.taskname not in self.models and (allowed_modules is None or model.taskname in allowed_modules):
# print('find model:', onnx_file, model.taskname, model.input_shape, model.input_mean, model.input_std)
self.models[model.taskname] = model
else:
print('duplicated model task type, ignore:', onnx_file, model.taskname)
del model
assert 'detection' in self.models
self.det_model = self.models['detection']
def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)):
self.det_thresh = det_thresh
assert det_size is not None
# print('set det-size:', det_size)
self.det_size = det_size
for taskname, model in self.models.items():
if taskname=='detection':
model.prepare(ctx_id, input_size=det_size, det_thresh=det_thresh)
else:
model.prepare(ctx_id)
def get(self, img, max_num=0):
bboxes, kpss = self.det_model.detect(img,
max_num=max_num,
metric='default')
if bboxes.shape[0] == 0:
return []
ret = []
for i in range(bboxes.shape[0]):
bbox = bboxes[i, 0:4]
det_score = bboxes[i, 4]
kps = None
if kpss is not None:
kps = kpss[i]
face = Face(bbox=bbox, kps=kps, det_score=det_score)
for taskname, model in self.models.items():
if taskname=='detection':
continue
model.get(img, face)
ret.append(face)
return ret
def draw_on(self, img, faces):
import cv2
dimg = img.copy()
for i in range(len(faces)):
face = faces[i]
box = face.bbox.astype(np.int)
color = (0, 0, 255)
cv2.rectangle(dimg, (box[0], box[1]), (box[2], box[3]), color, 2)
if face.kps is not None:
kps = face.kps.astype(np.int)
#print(landmark.shape)
for l in range(kps.shape[0]):
color = (0, 0, 255)
if l == 0 or l == 3:
color = (0, 255, 0)
cv2.circle(dimg, (kps[l][0], kps[l][1]), 1, color,
2)
if face.gender is not None and face.age is not None:
cv2.putText(dimg,'%s,%d'%(face.sex,face.age), (box[0]-1, box[1]-4),cv2.FONT_HERSHEY_COMPLEX,0.7,(0,255,0),1)
#for key, value in face.items():
# if key.startswith('landmark_3d'):
# print(key, value.shape)
# print(value[0:10,:])
# lmk = np.round(value).astype(np.int)
# for l in range(lmk.shape[0]):
# color = (255, 0, 0)
# cv2.circle(dimg, (lmk[l][0], lmk[l][1]), 1, color,
# 2)
return dimg
from .image import get_image
from .pickle_object import get_object
import cv2
import os
import os.path as osp
from pathlib import Path
class ImageCache:
data = {}
def get_image(name, to_rgb=False):
key = (name, to_rgb)
if key in ImageCache.data:
return ImageCache.data[key]
images_dir = osp.join(Path(__file__).parent.absolute(), 'images')
ext_names = ['.jpg', '.png', '.jpeg']
image_file = None
for ext_name in ext_names:
_image_file = osp.join(images_dir, "%s%s"%(name, ext_name))
if osp.exists(_image_file):
image_file = _image_file
break
assert image_file is not None, '%s not found'%name
img = cv2.imread(image_file)
if to_rgb:
img = img[:,:,::-1]
ImageCache.data[key] = img
return img
import cv2
import os
import os.path as osp
from pathlib import Path
import pickle
def get_object(name):
objects_dir = osp.join(Path(__file__).parent.absolute(), 'objects')
if not name.endswith('.pkl'):
name = name+".pkl"
filepath = osp.join(objects_dir, name)
if not osp.exists(filepath):
return None
with open(filepath, 'rb') as f:
obj = pickle.load(f)
return obj
import pickle
import numpy as np
import os
import os.path as osp
import sys
import mxnet as mx
class RecBuilder():
def __init__(self, path, image_size=(112, 112)):
self.path = path
self.image_size = image_size
self.widx = 0
self.wlabel = 0
self.max_label = -1
assert not osp.exists(path), '%s exists' % path
os.makedirs(path)
self.writer = mx.recordio.MXIndexedRecordIO(os.path.join(path, 'train.idx'),
os.path.join(path, 'train.rec'),
'w')
self.meta = []
def add(self, imgs):
#!!! img should be BGR!!!!
#assert label >= 0
#assert label > self.last_label
assert len(imgs) > 0
label = self.wlabel
for img in imgs:
idx = self.widx
image_meta = {'image_index': idx, 'image_classes': [label]}
header = mx.recordio.IRHeader(0, label, idx, 0)
if isinstance(img, np.ndarray):
s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg')
else:
s = mx.recordio.pack(header, img)
self.writer.write_idx(idx, s)
self.meta.append(image_meta)
self.widx += 1
self.max_label = label
self.wlabel += 1
def add_image(self, img, label):
#!!! img should be BGR!!!!
#assert label >= 0
#assert label > self.last_label
idx = self.widx
header = mx.recordio.IRHeader(0, label, idx, 0)
if isinstance(label, list):
idlabel = label[0]
else:
idlabel = label
image_meta = {'image_index': idx, 'image_classes': [idlabel]}
if isinstance(img, np.ndarray):
s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg')
else:
s = mx.recordio.pack(header, img)
self.writer.write_idx(idx, s)
self.meta.append(image_meta)
self.widx += 1
self.max_label = max(self.max_label, idlabel)
def close(self):
with open(osp.join(self.path, 'train.meta'), 'wb') as pfile:
pickle.dump(self.meta, pfile, protocol=pickle.HIGHEST_PROTOCOL)
print('stat:', self.widx, self.wlabel)
with open(os.path.join(self.path, 'property'), 'w') as f:
f.write("%d,%d,%d\n" % (self.max_label+1, self.image_size[0], self.image_size[1]))
f.write("%d\n" % (self.widx))
from .model_zoo import get_model
from .arcface_onnx import ArcFaceONNX
from .retinaface import RetinaFace
from .scrfd import SCRFD
from .landmark import Landmark
from .attribute import Attribute
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
from __future__ import division
import numpy as np
import cv2
import onnx
import onnxruntime
from ..utils import face_align
__all__ = [
'ArcFaceONNX',
]
class ArcFaceONNX:
def __init__(self, model_file=None, session=None):
assert model_file is not None
self.model_file = model_file
self.session = session
self.taskname = 'recognition'
find_sub = False
find_mul = False
model = onnx.load(self.model_file)
graph = model.graph
for nid, node in enumerate(graph.node[:8]):
#print(nid, node.name)
if node.name.startswith('Sub') or node.name.startswith('_minus'):
find_sub = True
if node.name.startswith('Mul') or node.name.startswith('_mul'):
find_mul = True
if find_sub and find_mul:
#mxnet arcface model
input_mean = 0.0
input_std = 1.0
else:
input_mean = 127.5
input_std = 127.5
self.input_mean = input_mean
self.input_std = input_std
#print('input mean and std:', self.input_mean, self.input_std)
if self.session is None:
self.session = onnxruntime.InferenceSession(self.model_file, None)
input_cfg = self.session.get_inputs()[0]
input_shape = input_cfg.shape
input_name = input_cfg.name
self.input_size = tuple(input_shape[2:4][::-1])
self.input_shape = input_shape
outputs = self.session.get_outputs()
output_names = []
for out in outputs:
output_names.append(out.name)
self.input_name = input_name
self.output_names = output_names
assert len(self.output_names)==1
self.output_shape = outputs[0].shape
def prepare(self, ctx_id, **kwargs):
if ctx_id<0:
self.session.set_providers(['CPUExecutionProvider'])
def get(self, img, face):
aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0])
face.embedding = self.get_feat(aimg).flatten()
return face.embedding
def compute_sim(self, feat1, feat2):
from numpy.linalg import norm
feat1 = feat1.ravel()
feat2 = feat2.ravel()
sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2))
return sim
def get_feat(self, imgs):
if not isinstance(imgs, list):
imgs = [imgs]
input_size = self.input_size
blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size,
(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
return net_out
def forward(self, batch_data):
blob = (batch_data - self.input_mean) / self.input_std
net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
return net_out
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-06-19
# @Function :
from __future__ import division
import numpy as np
import cv2
import onnx
import onnxruntime
from ..utils import face_align
__all__ = [
'Attribute',
]
class Attribute:
def __init__(self, model_file=None, session=None):
assert model_file is not None
self.model_file = model_file
self.session = session
find_sub = False
find_mul = False
model = onnx.load(self.model_file)
graph = model.graph
for nid, node in enumerate(graph.node[:8]):
#print(nid, node.name)
if node.name.startswith('Sub') or node.name.startswith('_minus'):
find_sub = True
if node.name.startswith('Mul') or node.name.startswith('_mul'):
find_mul = True
if nid<3 and node.name=='bn_data':
find_sub = True
find_mul = True
if find_sub and find_mul:
#mxnet arcface model
input_mean = 0.0
input_std = 1.0
else:
input_mean = 127.5
input_std = 128.0
self.input_mean = input_mean
self.input_std = input_std
#print('input mean and std:', model_file, self.input_mean, self.input_std)
if self.session is None:
self.session = onnxruntime.InferenceSession(self.model_file, None)
input_cfg = self.session.get_inputs()[0]
input_shape = input_cfg.shape
input_name = input_cfg.name
self.input_size = tuple(input_shape[2:4][::-1])
self.input_shape = input_shape
outputs = self.session.get_outputs()
output_names = []
for out in outputs:
output_names.append(out.name)
self.input_name = input_name
self.output_names = output_names
assert len(self.output_names)==1
output_shape = outputs[0].shape
#print('init output_shape:', output_shape)
if output_shape[1]==3:
self.taskname = 'genderage'
else:
self.taskname = 'attribute_%d'%output_shape[1]
def prepare(self, ctx_id, **kwargs):
if ctx_id<0:
self.session.set_providers(['CPUExecutionProvider'])
def get(self, img, face):
bbox = face.bbox
w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
rotate = 0
_scale = self.input_size[0] / (max(w, h)*1.5)
#print('param:', img.shape, bbox, center, self.input_size, _scale, rotate)
aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate)
input_size = tuple(aimg.shape[0:2][::-1])
#assert input_size==self.input_size
blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
pred = self.session.run(self.output_names, {self.input_name : blob})[0][0]
if self.taskname=='genderage':
assert len(pred)==3
gender = np.argmax(pred[:2])
age = int(np.round(pred[2]*100))
face['gender'] = gender
face['age'] = age
return gender, age
else:
return pred
import time
import numpy as np
import onnxruntime
import cv2
import onnx
from onnx import numpy_helper
from ..utils import face_align
class INSwapper():
def __init__(self, model_file=None, session=None):
self.model_file = model_file
self.session = session
model = onnx.load(self.model_file)
graph = model.graph
self.emap = numpy_helper.to_array(graph.initializer[-1])
self.input_mean = 0.0
self.input_std = 255.0
#print('input mean and std:', model_file, self.input_mean, self.input_std)
if self.session is None:
self.session = onnxruntime.InferenceSession(self.model_file, None)
inputs = self.session.get_inputs()
self.input_names = []
for inp in inputs:
self.input_names.append(inp.name)
outputs = self.session.get_outputs()
output_names = []
for out in outputs:
output_names.append(out.name)
self.output_names = output_names
assert len(self.output_names)==1
output_shape = outputs[0].shape
input_cfg = inputs[0]
input_shape = input_cfg.shape
self.input_shape = input_shape
# print('inswapper-shape:', self.input_shape)
self.input_size = tuple(input_shape[2:4][::-1])
def forward(self, img, latent):
img = (img - self.input_mean) / self.input_std
pred = self.session.run(self.output_names, {self.input_names[0]: img, self.input_names[1]: latent})[0]
return pred
def get(self, img, target_face, source_face, paste_back=True):
face_mask = np.zeros((img.shape[0], img.shape[1]), np.uint8)
cv2.fillPoly(face_mask, np.array([target_face.landmark_2d_106[[1,9,10,11,12,13,14,15,16,2,3,4,5,6,7,8,0,24,23,22,21,20,19,18,32,31,30,29,28,27,26,25,17,101,105,104,103,51,49,48,43]].astype('int64')]), 1)
aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0])
blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size,
(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
latent = source_face.normed_embedding.reshape((1,-1))
latent = np.dot(latent, self.emap)
latent /= np.linalg.norm(latent)
pred = self.session.run(self.output_names, {self.input_names[0]: blob, self.input_names[1]: latent})[0]
#print(latent.shape, latent.dtype, pred.shape)
img_fake = pred.transpose((0,2,3,1))[0]
bgr_fake = np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1]
if not paste_back:
return bgr_fake, M
else:
target_img = img
fake_diff = bgr_fake.astype(np.float32) - aimg.astype(np.float32)
fake_diff = np.abs(fake_diff).mean(axis=2)
fake_diff[:2,:] = 0
fake_diff[-2:,:] = 0
fake_diff[:,:2] = 0
fake_diff[:,-2:] = 0
IM = cv2.invertAffineTransform(M)
img_white = np.full((aimg.shape[0],aimg.shape[1]), 255, dtype=np.float32)
bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
fake_diff = cv2.warpAffine(fake_diff, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
img_white[img_white>20] = 255
fthresh = 10
fake_diff[fake_diff<fthresh] = 0
fake_diff[fake_diff>=fthresh] = 255
img_mask = img_white
mask_h_inds, mask_w_inds = np.where(img_mask==255)
mask_h = np.max(mask_h_inds) - np.min(mask_h_inds)
mask_w = np.max(mask_w_inds) - np.min(mask_w_inds)
mask_size = int(np.sqrt(mask_h*mask_w))
k = max(mask_size//10, 10)
#k = max(mask_size//20, 6)
#k = 6
kernel = np.ones((k,k),np.uint8)
img_mask = cv2.erode(img_mask,kernel,iterations = 1)
kernel = np.ones((2,2),np.uint8)
fake_diff = cv2.dilate(fake_diff,kernel,iterations = 1)
face_mask = cv2.erode(face_mask,np.ones((11,11),np.uint8),iterations = 1)
fake_diff[face_mask==1] = 255
k = max(mask_size//20, 5)
#k = 3
#k = 3
kernel_size = (k, k)
blur_size = tuple(2*i+1 for i in kernel_size)
img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
k = 5
kernel_size = (k, k)
blur_size = tuple(2*i+1 for i in kernel_size)
fake_diff = cv2.blur(fake_diff, (11,11), 0)
##fake_diff = cv2.GaussianBlur(fake_diff, blur_size, 0)
# print('blur_size: ', blur_size)
# fake_diff = cv2.blur(fake_diff, (21, 21), 0) # blur_size
img_mask /= 255
fake_diff /= 255
# img_mask = fake_diff
img_mask = img_mask*fake_diff
img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
fake_merged = img_mask * bgr_fake + (1-img_mask) * target_img.astype(np.float32)
fake_merged = fake_merged.astype(np.uint8)
return fake_merged
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function :
from __future__ import division
import numpy as np
import cv2
import onnx
import onnxruntime
from ..utils import face_align
from ..utils import transform
from ..data import get_object
__all__ = [
'Landmark',
]
class Landmark:
def __init__(self, model_file=None, session=None):
assert model_file is not None
self.model_file = model_file
self.session = session
find_sub = False
find_mul = False
model = onnx.load(self.model_file)
graph = model.graph
for nid, node in enumerate(graph.node[:8]):
#print(nid, node.name)
if node.name.startswith('Sub') or node.name.startswith('_minus'):
find_sub = True
if node.name.startswith('Mul') or node.name.startswith('_mul'):
find_mul = True
if nid<3 and node.name=='bn_data':
find_sub = True
find_mul = True
if find_sub and find_mul:
#mxnet arcface model
input_mean = 0.0
input_std = 1.0
else:
input_mean = 127.5
input_std = 128.0
self.input_mean = input_mean
self.input_std = input_std
#print('input mean and std:', model_file, self.input_mean, self.input_std)
if self.session is None:
self.session = onnxruntime.InferenceSession(self.model_file, None)
input_cfg = self.session.get_inputs()[0]
input_shape = input_cfg.shape
input_name = input_cfg.name
self.input_size = tuple(input_shape[2:4][::-1])
self.input_shape = input_shape
outputs = self.session.get_outputs()
output_names = []
for out in outputs:
output_names.append(out.name)
self.input_name = input_name
self.output_names = output_names
assert len(self.output_names)==1
output_shape = outputs[0].shape
self.require_pose = False
#print('init output_shape:', output_shape)
if output_shape[1]==3309:
self.lmk_dim = 3
self.lmk_num = 68
self.mean_lmk = get_object('meanshape_68.pkl')
self.require_pose = True
else:
self.lmk_dim = 2
self.lmk_num = output_shape[1]//self.lmk_dim
self.taskname = 'landmark_%dd_%d'%(self.lmk_dim, self.lmk_num)
def prepare(self, ctx_id, **kwargs):
if ctx_id<0:
self.session.set_providers(['CPUExecutionProvider'])
def get(self, img, face):
bbox = face.bbox
w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
rotate = 0
_scale = self.input_size[0] / (max(w, h)*1.5)
#print('param:', img.shape, bbox, center, self.input_size, _scale, rotate)
aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate)
input_size = tuple(aimg.shape[0:2][::-1])
#assert input_size==self.input_size
blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
pred = self.session.run(self.output_names, {self.input_name : blob})[0][0]
if pred.shape[0] >= 3000:
pred = pred.reshape((-1, 3))
else:
pred = pred.reshape((-1, 2))
if self.lmk_num < pred.shape[0]:
pred = pred[self.lmk_num*-1:,:]
pred[:, 0:2] += 1
pred[:, 0:2] *= (self.input_size[0] // 2)
if pred.shape[1] == 3:
pred[:, 2] *= (self.input_size[0] // 2)
IM = cv2.invertAffineTransform(M)
pred = face_align.trans_points(pred, IM)
face[self.taskname] = pred
if self.require_pose:
P = transform.estimate_affine_matrix_3d23d(self.mean_lmk, pred)
s, R, t = transform.P2sRt(P)
rx, ry, rz = transform.matrix2angle(R)
pose = np.array( [rx, ry, rz], dtype=np.float32 )
face['pose'] = pose #pitch, yaw, roll
return pred
"""
This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_store.py
"""
from __future__ import print_function
__all__ = ['get_model_file']
import os
import zipfile
import glob
from ..utils import download, check_sha1
_model_sha1 = {
name: checksum
for checksum, name in [
('95be21b58e29e9c1237f229dae534bd854009ce0', 'arcface_r100_v1'),
('', 'arcface_mfn_v1'),
('39fd1e087a2a2ed70a154ac01fecaa86c315d01b', 'retinaface_r50_v1'),
('2c9de8116d1f448fd1d4661f90308faae34c990a', 'retinaface_mnet025_v1'),
('0db1d07921d005e6c9a5b38e059452fc5645e5a4', 'retinaface_mnet025_v2'),
('7dd8111652b7aac2490c5dcddeb268e53ac643e6', 'genderage_v1'),
]
}
base_repo_url = 'https://insightface.ai/files/'
_url_format = '{repo_url}models/{file_name}.zip'
def short_hash(name):
if name not in _model_sha1:
raise ValueError(
'Pretrained model for {name} is not available.'.format(name=name))
return _model_sha1[name][:8]
def find_params_file(dir_path):
if not os.path.exists(dir_path):
return None
paths = glob.glob("%s/*.params" % dir_path)
if len(paths) == 0:
return None
paths = sorted(paths)
return paths[-1]
def get_model_file(name, root=os.path.join('~', '.insightface', 'models')):
r"""Return location for the pretrained on local file system.
This function will download from online model zoo when model cannot be found or has mismatch.
The root directory will be created if it doesn't exist.
Parameters
----------
name : str
Name of the model.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
Returns
-------
file_path
Path to the requested pretrained model file.
"""
file_name = name
root = os.path.expanduser(root)
dir_path = os.path.join(root, name)
file_path = find_params_file(dir_path)
#file_path = os.path.join(root, file_name + '.params')
sha1_hash = _model_sha1[name]
if file_path is not None:
if check_sha1(file_path, sha1_hash):
return file_path
else:
print(
'Mismatch in the content of model file detected. Downloading again.'
)
else:
print('Model file is not found. Downloading.')
if not os.path.exists(root):
os.makedirs(root)
if not os.path.exists(dir_path):
os.makedirs(dir_path)
zip_file_path = os.path.join(root, file_name + '.zip')
repo_url = base_repo_url
if repo_url[-1] != '/':
repo_url = repo_url + '/'
download(_url_format.format(repo_url=repo_url, file_name=file_name),
path=zip_file_path,
overwrite=True)
with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall(dir_path)
os.remove(zip_file_path)
file_path = find_params_file(dir_path)
if check_sha1(file_path, sha1_hash):
return file_path
else:
raise ValueError(
'Downloaded file has different hash. Please try again.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment