gaoqiong / lm-evaluation-harness
Unverified commit 0dffdbb4, authored Mar 27, 2024 by Hailey Schoelkopf; committed via GitHub on Mar 27, 2024.
Fix conditional import for Nemo LM class (#1641)
parent e9d429e1

Showing 1 changed file with 45 additions and 22 deletions:
lm_eval/models/nemo_lm.py (+45, -22)
lm_eval/models/nemo_lm.py
@@ -34,29 +34,17 @@ from lm_eval.utils import (
 )

-try:
-    from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import (
-        MegatronGPTModel,
-    )
-    from nemo.collections.nlp.modules.common.text_generation_utils import generate
-    from nemo.collections.nlp.parts.nlp_overrides import (
-        NLPDDPStrategy,
-        NLPSaveRestoreConnector,
-    )
-    from nemo.utils.app_state import AppState
-    from pytorch_lightning.trainer.trainer import Trainer
-except ModuleNotFoundError:
-    raise Exception(
-        "Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
-        "Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
-        "or installing nemo following https://github.com/NVIDIA/NeMo.",
-    )
-
 def _patch_pretrained_cfg(
     pretrained_cfg, trainer, tensor_model_parallel_size, pipeline_model_parallel_size
 ):
-    import omegaconf
+    try:
+        import omegaconf
+    except ModuleNotFoundError:
+        raise Exception(
+            "Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
+            "Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
+            "or installing nemo following https://github.com/NVIDIA/NeMo.",
+        )
+
     omegaconf.OmegaConf.set_struct(pretrained_cfg, True)
     with omegaconf.open_dict(pretrained_cfg):
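The hunk above moves the heavy `nemo` imports out of module scope: `_patch_pretrained_cfg` now imports `omegaconf` only on first use and converts a missing package into an actionable error. A minimal sketch of that deferred-import pattern, with illustrative names (`heavy_dep` and `run_with_heavy_dep` are not from the repo):

# Sketch of the deferred-import pattern applied in this commit: the optional
# dependency is imported inside the function that needs it, so merely
# importing the module never fails when the dependency is absent.
def run_with_heavy_dep(data):
    try:
        import heavy_dep  # resolved only when this function is called
    except ModuleNotFoundError:
        raise Exception(
            "Attempted to use 'run_with_heavy_dep', but package `heavy_dep` "
            "is not installed. Install it before calling this function."
        )
    # `heavy_dep.process` is a hypothetical API used only for illustration.
    return heavy_dep.process(data)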
@@ -86,6 +74,17 @@ def load_model(
     tensor_model_parallel_size: int,
     pipeline_model_parallel_size: int,
 ) -> torch.nn.Module:
+    try:
+        from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import (
+            MegatronGPTModel,
+        )
+        from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
+    except ModuleNotFoundError:
+        raise Exception(
+            "Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
+            "Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
+            "or installing nemo following https://github.com/NVIDIA/NeMo.",
+        )
     model_path = pathlib.Path(model_path)
     save_restore_connector = NLPSaveRestoreConnector()
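With the import moved inside `load_model`, `lm_eval.models.nemo_lm` can now be imported on a machine without `nemo`; the error surfaces only when `load_model` actually runs. A hedged sketch of a check for that behavior. The argument names come from the visible hunk, but the full signature is not shown in this diff, so treat the call as illustrative:

# Sketch: on a machine without `nemo`, importing the module succeeds and the
# install-hint Exception is raised only on use.
import pytest

from lm_eval.models import nemo_lm  # should no longer raise at import time


def test_load_model_requires_nemo():
    with pytest.raises(Exception, match="nemo"):
        nemo_lm.load_model(
            "dummy.nemo",  # model_path
            None,          # trainer (illustrative placeholder)
            tensor_model_parallel_size=1,
            pipeline_model_parallel_size=1,
        )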
@@ -139,6 +138,15 @@ def load_model(
 def setup_distributed_environment(trainer):
+    try:
+        from nemo.utils.app_state import AppState
+    except ModuleNotFoundError:
+        raise Exception(
+            "Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
+            "Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
+            "or installing nemo following https://github.com/NVIDIA/NeMo.",
+        )
+
     def dummy():
         return
@@ -178,6 +186,21 @@ class NeMoLM(LM):
         ] = "bf16",
         **kwargs,
     ):
+        try:
+            from nemo.collections.nlp.modules.common.text_generation_utils import (
+                generate,
+            )
+            from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
+            from pytorch_lightning.trainer.trainer import Trainer
+
+            self.generate = generate
+        except ModuleNotFoundError:
+            raise Exception(
+                "Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
+                "Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
+                "or installing nemo following https://github.com/NVIDIA/NeMo.",
+            )
+
         super().__init__()
         if (
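Storing the imported function on the instance (`self.generate = generate`) is what lets the two call sites below switch from the module-level `generate`, which no longer exists, to `self.generate`. A toy sketch of the same binding pattern; the class and names here are illustrative, and a stdlib function stands in for the optional import:

# Toy sketch: stash an optionally-imported callable on the instance during
# __init__, then call it through `self` later. No module-level name is needed.
class OptionalBackend:
    def __init__(self):
        try:
            from heapq import nlargest  # stand-in for an optional dependency

            self._nlargest = nlargest  # bound once, used by methods below
        except ModuleNotFoundError:
            raise Exception("optional backend is not installed")

    def top3(self, xs):
        # Only the bound attribute exists; there is no module-level `nlargest`.
        return self._nlargest(3, xs)


print(OptionalBackend().top3([5, 1, 9, 7]))  # [9, 7, 5]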
@@ -396,7 +419,7 @@ class NeMoLM(LM):
             inps.append(self.tok_decode(inp))

-        output = generate(
+        output = self.generate(
             self.model,
             inputs=inps,
             tokens_to_generate=1,
@@ -490,7 +513,7 @@ class NeMoLM(LM):
                 encoded_context = encoded_context[-remaining_length:]
             contexts.append(self.tok_decode(encoded_context))

-        output = generate(
+        output = self.generate(
             self.model,
             inputs=contexts,
             tokens_to_generate=max_gen_toks,
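One side effect of the change: the install-hint `Exception` is now duplicated at four sites, and its adjacent string literals concatenate with no separator ("...is not installed" runs straight into "Please install..."). A hypothetical follow-up, not part of this commit, could hoist the message into a single helper; `_missing_nemo_error` is an invented name:

# Hypothetical refactor sketch (not in this commit): keep the repeated
# install hint in one place so all four call sites stay in sync.
NEMO_INSTALL_HINT = (
    "Attempted to use 'nemo_lm' model type, but package `nemo` is not installed. "
    "Please install nemo following the instructions in the README: either with "
    "a NVIDIA PyTorch or NeMo container, or by installing nemo following "
    "https://github.com/NVIDIA/NeMo."
)


def _missing_nemo_error() -> Exception:
    return Exception(NEMO_INSTALL_HINT)


# Each call site would then shrink to:
# try:
#     from nemo.utils.app_state import AppState
# except ModuleNotFoundError:
#     raise _missing_nemo_error()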