Commit 87f0d198 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Fix loading checkpoint in hubert preprocessing (#2310)

Summary:
When the checkpoint was saved on a GPU device and preprocessing runs on CPU, the script throws an exception. Fix it by loading the model state dictionary onto CPU by default.
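For reference, a minimal sketch of the pattern this change adopts (not the repository code itself; the helper name below is illustrative of `_load_state` in the diff):

```python
import torch

_DEFAULT_DEVICE = torch.device("cpu")


def load_checkpoint_state_dict(checkpoint_path, device=_DEFAULT_DEVICE):
    """Load a HuBERTPretrainModel checkpoint onto ``device`` (CPU by default).

    Without ``map_location``, tensors saved from a GPU run are restored onto
    their original CUDA device, which raises a RuntimeError on a CPU-only host.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # The training checkpoint nests the weights under "state_dict" with a "model." prefix.
    return {k.replace("model.", ""): v for k, v in checkpoint["state_dict"].items()}
```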

Pull Request resolved: https://github.com/pytorch/audio/pull/2310

Reviewed By: mthrok

Differential Revision: D35316903

Pulled By: nateanl

fbshipit-source-id: d3e7183400ba133240aa6d205f5c671a421a9fed
parent 3ed39e15
@@ -19,6 +19,7 @@ from torch.nn import Module
 from .common_utils import _get_feat_lens_paths

 _LG = logging.getLogger(__name__)
+_DEFAULT_DEVICE = torch.device("cpu")


 def get_shard_range(num_lines: int, num_rank: int, rank: int) -> Tuple[int, int]:
@@ -105,16 +106,17 @@ def extract_feature_hubert(
     return feat


-def _load_state(model: Module, checkpoint_path: Path) -> Module:
+def _load_state(model: Module, checkpoint_path: Path, device=_DEFAULT_DEVICE) -> Module:
     """Load weights from HuBERTPretrainModel checkpoint into hubert_pretrain_base model.

     Args:
         model (Module): The hubert_pretrain_base model.
         checkpoint_path (Path): The model checkpoint.
+        device (torch.device, optional): The device of the model. (Default: ``torch.device("cpu")``)

     Returns:
         (Module): The pretrained model.
     """
-    state_dict = torch.load(checkpoint_path)
+    state_dict = torch.load(checkpoint_path, map_location=device)
     state_dict = {k.replace("model.", ""): v for k, v in state_dict["state_dict"].items()}
     model.load_state_dict(state_dict)
     return model
@@ -169,8 +171,8 @@ def dump_features(
     from torchaudio.models import hubert_pretrain_base

     model = hubert_pretrain_base()
-    model = _load_state(model, checkpoint_path)
     model.to(device)
+    model = _load_state(model, checkpoint_path, device)
     with open(tsv_file, "r") as f:
         root = f.readline().rstrip()
...