Unverified commit 0dffdbb4, authored by Hailey Schoelkopf and committed by GitHub
Browse files

Fix conditional import for Nemo LM class (#1641)

parent e9d429e1
...@@ -34,29 +34,17 @@ from lm_eval.utils import ( ...@@ -34,29 +34,17 @@ from lm_eval.utils import (
) )
try:
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import (
MegatronGPTModel,
)
from nemo.collections.nlp.modules.common.text_generation_utils import generate
from nemo.collections.nlp.parts.nlp_overrides import (
NLPDDPStrategy,
NLPSaveRestoreConnector,
)
from nemo.utils.app_state import AppState
from pytorch_lightning.trainer.trainer import Trainer
except ModuleNotFoundError:
raise Exception(
"Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
"Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
"or installing nemo following https://github.com/NVIDIA/NeMo.",
)
def _patch_pretrained_cfg( def _patch_pretrained_cfg(
pretrained_cfg, trainer, tensor_model_parallel_size, pipeline_model_parallel_size pretrained_cfg, trainer, tensor_model_parallel_size, pipeline_model_parallel_size
): ):
import omegaconf try:
import omegaconf
except ModuleNotFoundError:
raise Exception(
"Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
"Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
"or installing nemo following https://github.com/NVIDIA/NeMo.",
)
omegaconf.OmegaConf.set_struct(pretrained_cfg, True) omegaconf.OmegaConf.set_struct(pretrained_cfg, True)
with omegaconf.open_dict(pretrained_cfg): with omegaconf.open_dict(pretrained_cfg):
...@@ -86,6 +74,17 @@ def load_model( ...@@ -86,6 +74,17 @@ def load_model(
tensor_model_parallel_size: int, tensor_model_parallel_size: int,
pipeline_model_parallel_size: int, pipeline_model_parallel_size: int,
) -> torch.nn.Module: ) -> torch.nn.Module:
try:
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import (
MegatronGPTModel,
)
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
except ModuleNotFoundError:
raise Exception(
"Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
"Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
"or installing nemo following https://github.com/NVIDIA/NeMo.",
)
model_path = pathlib.Path(model_path) model_path = pathlib.Path(model_path)
save_restore_connector = NLPSaveRestoreConnector() save_restore_connector = NLPSaveRestoreConnector()
...@@ -139,6 +138,15 @@ def load_model( ...@@ -139,6 +138,15 @@ def load_model(
def setup_distributed_environment(trainer): def setup_distributed_environment(trainer):
try:
from nemo.utils.app_state import AppState
except ModuleNotFoundError:
raise Exception(
"Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
"Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
"or installing nemo following https://github.com/NVIDIA/NeMo.",
)
def dummy(): def dummy():
return return
...@@ -178,6 +186,21 @@ class NeMoLM(LM): ...@@ -178,6 +186,21 @@ class NeMoLM(LM):
] = "bf16", ] = "bf16",
**kwargs, **kwargs,
): ):
try:
from nemo.collections.nlp.modules.common.text_generation_utils import (
generate,
)
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from pytorch_lightning.trainer.trainer import Trainer
self.generate = generate
except ModuleNotFoundError:
raise Exception(
"Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
"Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
"or installing nemo following https://github.com/NVIDIA/NeMo.",
)
super().__init__() super().__init__()
if ( if (
...@@ -396,7 +419,7 @@ class NeMoLM(LM): ...@@ -396,7 +419,7 @@ class NeMoLM(LM):
inps.append(self.tok_decode(inp)) inps.append(self.tok_decode(inp))
output = generate( output = self.generate(
self.model, self.model,
inputs=inps, inputs=inps,
tokens_to_generate=1, tokens_to_generate=1,
...@@ -490,7 +513,7 @@ class NeMoLM(LM): ...@@ -490,7 +513,7 @@ class NeMoLM(LM):
encoded_context = encoded_context[-remaining_length:] encoded_context = encoded_context[-remaining_length:]
contexts.append(self.tok_decode(encoded_context)) contexts.append(self.tok_decode(encoded_context))
output = generate( output = self.generate(
self.model, self.model,
inputs=contexts, inputs=contexts,
tokens_to_generate=max_gen_toks, tokens_to_generate=max_gen_toks,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment