Commit 10195316 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Make lazy import for joblib (#2498)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2498

Reviewed By: mthrok

Differential Revision: D37224024

Pulled By: nateanl

fbshipit-source-id: 5d5d561c43d1ee323ae0cc599ffa1479208ea09a
parent 74dcfba3
...@@ -7,7 +7,6 @@ import logging ...@@ -7,7 +7,6 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Tuple from typing import Tuple
import joblib
import torch import torch
from sklearn.cluster import MiniBatchKMeans from sklearn.cluster import MiniBatchKMeans
from torch import Tensor from torch import Tensor
...@@ -107,6 +106,8 @@ def learn_kmeans( ...@@ -107,6 +106,8 @@ def learn_kmeans(
feats = feats.numpy() feats = feats.numpy()
km_model.fit(feats) km_model.fit(feats)
km_path = _get_model_path(km_dir) km_path = _get_model_path(km_dir)
import joblib
joblib.dump(km_model, km_path) joblib.dump(km_model, km_path)
inertia = -km_model.score(feats) / len(feats) inertia = -km_model.score(feats) / len(feats)
...@@ -116,6 +117,8 @@ def learn_kmeans( ...@@ -116,6 +117,8 @@ def learn_kmeans(
class ApplyKmeans(object): class ApplyKmeans(object):
def __init__(self, km_path, device): def __init__(self, km_path, device):
import joblib
self.km_model = joblib.load(km_path) self.km_model = joblib.load(km_path)
self.C_np = self.km_model.cluster_centers_.transpose() self.C_np = self.km_model.cluster_centers_.transpose()
self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True) self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment