Commit 87f0d198 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Fix loading checkpoint in hubert preprocessing (#2310)

Summary:
When checkpoint is on GPU device and preprocessing is on CPU, the script will throw an exception error. Fix it to load the model state dictionary into CPU by default.

Pull Request resolved: https://github.com/pytorch/audio/pull/2310

Reviewed By: mthrok

Differential Revision: D35316903

Pulled By: nateanl

fbshipit-source-id: d3e7183400ba133240aa6d205f5c671a421a9fed
parent 3ed39e15
......@@ -19,6 +19,7 @@ from torch.nn import Module
from .common_utils import _get_feat_lens_paths
_LG = logging.getLogger(__name__)
_DEFAULT_DEVICE = torch.device("cpu")
def get_shard_range(num_lines: int, num_rank: int, rank: int) -> Tuple[int, int]:
......@@ -105,16 +106,17 @@ def extract_feature_hubert(
return feat
def _load_state(model: Module, checkpoint_path: Path) -> Module:
def _load_state(model: Module, checkpoint_path: Path, device=_DEFAULT_DEVICE) -> Module:
"""Load weights from HuBERTPretrainModel checkpoint into hubert_pretrain_base model.
Args:
model (Module): The hubert_pretrain_base model.
checkpoint_path (Path): The model checkpoint.
device (torch.device, optional): The device of the model. (Default: ``torch.device("cpu")``)
Returns:
(Module): The pretrained model.
"""
state_dict = torch.load(checkpoint_path)
state_dict = torch.load(checkpoint_path, map_location=device)
state_dict = {k.replace("model.", ""): v for k, v in state_dict["state_dict"].items()}
model.load_state_dict(state_dict)
return model
......@@ -169,8 +171,8 @@ def dump_features(
from torchaudio.models import hubert_pretrain_base
model = hubert_pretrain_base()
model = _load_state(model, checkpoint_path)
model.to(device)
model = _load_state(model, checkpoint_path, device)
with open(tsv_file, "r") as f:
root = f.readline().rstrip()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment