gaoqiong / lm-evaluation-harness

Commit 61fc5bfd, authored Apr 23, 2025 by artemorloff

pre-commit prettify

parent c1e43393
Showing 3 changed files with 33 additions and 21 deletions (+33, -21)
lm_eval/__main__.py (+6, -6)
lm_eval/api/eval_config.py (+12, -8)
lm_eval/evaluator.py (+15, -7)
lm_eval/__main__.py
```diff
@@ -8,22 +8,22 @@ from pathlib import Path
 from typing import Union

 from lm_eval import evaluator, utils
+from lm_eval.api.eval_config import (
+    EvaluationConfig,
+    TrackExplicitAction,
+    TrackExplicitStoreTrue,
+)

 # from lm_eval.evaluator import request_caching_arg_to_dict
 from lm_eval.loggers import EvaluationTracker, WandbLogger
 from lm_eval.tasks import TaskManager
 from lm_eval.utils import (
     handle_non_serializable,
     load_yaml_config,
     make_table,
     request_caching_arg_to_dict,
     # non_default_update,
     # parse_namespace,
 )
-from lm_eval.api.eval_config import (
-    TrackExplicitAction,
-    TrackExplicitStoreTrue,
-    EvaluationConfig,
-)


 def try_parse_json(value: str) -> Union[str, dict, None]:
```
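The body of `try_parse_json` is collapsed in this view; only its signature survives. A minimal sketch of what a helper with this signature plausibly does (the body below is an assumption, not taken from the diff):

```python
import json
from typing import Union


def try_parse_json(value: str) -> Union[str, dict, None]:
    """Hypothetical body: best-effort JSON parsing for CLI values."""
    if not value:
        return None
    try:
        parsed = json.loads(value)
        # Only treat JSON objects as dicts; other JSON types stay raw strings.
        return parsed if isinstance(parsed, dict) else value
    except json.JSONDecodeError:
        return value
```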
lm_eval/api/eval_config.py
```diff
+import argparse
 import os
-import yaml
 from argparse import Namespace
-from typing import Any, Dict, Union, Optional
-import argparse
+from typing import Any, Dict, Optional, Union
+
+import yaml
 from pydantic import BaseModel
+
 from lm_eval.utils import simple_parse_args_string
```
```diff
@@ -21,6 +23,7 @@ class EvaluationConfig(BaseModel):
     Simple config container for language-model evaluation.
     No content validation here—just holds whatever comes from YAML or CLI.
     """
+
     config: Optional[str]
     model: Optional[str]
     model_args: Optional[dict]
```
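For orientation, these fields make `EvaluationConfig` a plain Pydantic container. A runnable sketch using only the three fields visible in this hunk (defaults are added here so the sketch runs standalone; the real class declares many more fields):

```python
from typing import Optional

from pydantic import BaseModel


class EvaluationConfigSketch(BaseModel):
    # Subset of the fields visible above, with defaults for illustration.
    config: Optional[str] = None
    model: Optional[str] = None
    model_args: Optional[dict] = None


cfg = EvaluationConfigSketch(model="hf", model_args={"pretrained": "gpt2"})
print(cfg.model)       # hf
print(cfg.model_args)  # {'pretrained': 'gpt2'}
```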
```diff
@@ -54,7 +57,6 @@ class EvaluationConfig(BaseModel):
     metadata: Optional[dict]
     request_caching_args: Optional[dict] = None

-
     @staticmethod
     def parse_namespace(namespace: argparse.Namespace) -> Dict[str, Any]:
         """
```
```diff
@@ -90,7 +92,6 @@ class EvaluationConfig(BaseModel):
         return config, non_default_args

-
     @staticmethod
     def non_default_update(console_dict, local_dict, non_default_args):
         """
```
```diff
@@ -117,7 +118,6 @@ class EvaluationConfig(BaseModel):
         return result_config

-
     @classmethod
     def from_cli(cls, namespace: Namespace) -> "EvaluationConfig":
         """
```
```diff
@@ -142,7 +142,9 @@ class EvaluationConfig(BaseModel):
             except yaml.YAMLError as e:
                 raise ValueError(f"Invalid YAML in {cfg_path}: {e}")
             if not isinstance(yaml_data, dict):
-                raise ValueError(f"YAML root must be a mapping, got {type(yaml_data).__name__}")
+                raise ValueError(
+                    f"YAML root must be a mapping, got {type(yaml_data).__name__}"
+                )
             config_data.update(yaml_data)

         # 3. Override with any CLI args the user explicitly passed
```
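To make the mapping guard concrete: `yaml.safe_load` (assumed here to be how `yaml_data` is produced) returns a list or scalar for non-mapping roots, which this branch rejects before merging into `config_data`. A standalone reproduction:

```python
import yaml

for raw in ("model: hf\nbatch_size: 8\n", "- not\n- a\n- mapping\n"):
    yaml_data = yaml.safe_load(raw)
    if not isinstance(yaml_data, dict):
        # Same check as in from_cli above.
        print(f"rejected: YAML root must be a mapping, got {type(yaml_data).__name__}")
    else:
        print(f"accepted: {yaml_data}")
```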
```diff
@@ -153,7 +155,9 @@ class EvaluationConfig(BaseModel):
         # config_data[key] = val

         print(f"YAML: {config_data}")
         print(f"CLI: {args_dict}")
-        dict_config = EvaluationConfig.non_default_update(args_dict, config_data, explicit_args)
+        dict_config = EvaluationConfig.non_default_update(
+            args_dict, config_data, explicit_args
+        )

         # 4. Instantiate the Pydantic model (no further validation here)
         return cls(**dict_config)
```
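Taken together, `from_cli` appears to implement a three-layer precedence: argparse defaults, then YAML file values, then explicitly passed CLI flags. A toy demonstration of that ordering with hypothetical values:

```python
defaults = {"model": "hf", "batch_size": 1, "num_fewshot": None}
yaml_values = {"batch_size": 8, "num_fewshot": 5}  # from --config file
explicit_cli = {"num_fewshot": 0}                  # user typed --num_fewshot 0

# Later layers win: YAML overrides defaults, explicit CLI overrides YAML.
config_data = {**defaults, **yaml_values, **explicit_cli}
print(config_data)  # {'model': 'hf', 'batch_size': 8, 'num_fewshot': 0}
```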
lm_eval/evaluator.py
```diff
@@ -4,7 +4,7 @@ import logging
 import random
 import time
 from collections import defaultdict
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Optional, Union

 import numpy as np
 import torch
```
```diff
@@ -13,6 +13,7 @@ import lm_eval.api.metrics
 import lm_eval.api.registry
 import lm_eval.api.task
 import lm_eval.models
+from lm_eval.api.eval_config import EvaluationConfig
 from lm_eval.caching.cache import delete_cache
 from lm_eval.evaluator_utils import (
     consolidate_group_results,
```

```diff
@@ -34,7 +35,6 @@ from lm_eval.utils import (
     setup_logging,
     simple_parse_args_string,
 )
-from lm_eval.api.eval_config import EvaluationConfig

 if TYPE_CHECKING:
```
```diff
@@ -215,7 +215,9 @@ def simple_evaluate(
         lm = config.model
     if config.use_cache is not None:
-        eval_logger.info(f"Using cache at {config.use_cache + '_rank' + str(lm.rank) + '.db'}")
+        eval_logger.info(
+            f"Using cache at {config.use_cache + '_rank' + str(lm.rank) + '.db'}"
+        )
         lm = lm_eval.api.model.CachingLM(
             lm,
             config.use_cache
```
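The cache file name is assembled from the user-supplied `--use_cache` prefix plus the process rank, so each rank in a distributed run writes its own `.db` file. A small illustration of the naming scheme (prefix and rank values are hypothetical):

```python
use_cache = "cache/my_eval"  # hypothetical --use_cache value
rank = 2                     # stand-in for lm.rank

# Mirrors the expression inside the eval_logger.info call above.
cache_path = use_cache + "_rank" + str(rank) + ".db"
print(cache_path)  # cache/my_eval_rank2.db
```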
```diff
@@ -249,7 +251,9 @@ def simple_evaluate(
         if task_obj.get_config("output_type") == "generate_until":
             if config.gen_kwargs is not None:
                 task_obj.set_config(
-                    key="generation_kwargs", value=config.gen_kwargs, update=True
+                    key="generation_kwargs",
+                    value=config.gen_kwargs,
+                    update=True,
                 )
                 eval_logger.info(
                     f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
```
```diff
@@ -271,7 +275,7 @@ def simple_evaluate(
                     )
                 else:
                     eval_logger.warning(
-                        f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
+                        f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {config.num_fewshot}"
                     )
                 task_obj.set_config(key="num_fewshot", value=config.num_fewshot)
             else:
```
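The fix in this hunk is worth noting: the warning previously interpolated a bare `num_fewshot`, while the value actually applied is `config.num_fewshot`; both now agree. A condensed sketch of the override branch (helper name and structure are assumptions):

```python
import logging

eval_logger = logging.getLogger("lm-eval")


def apply_num_fewshot(task_obj, task_name: str, config) -> None:
    # Hypothetical condensation of the visible branch.
    default_num_fewshot = task_obj.get_config("num_fewshot")
    if config.num_fewshot is not None:
        eval_logger.warning(
            f"Overwriting default num_fewshot of {task_name} "
            f"from {default_num_fewshot} to {config.num_fewshot}"
        )
        task_obj.set_config(key="num_fewshot", value=config.num_fewshot)
```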
```diff
@@ -309,7 +313,9 @@ def simple_evaluate(
         limit=config.limit,
         samples=config.samples,
         cache_requests=config.cache_requests,
-        rewrite_requests_cache=config.request_caching_args.get("rewrite_requests_cache", False),
+        rewrite_requests_cache=config.request_caching_args.get(
+            "rewrite_requests_cache", False
+        ),
         bootstrap_iters=bootstrap_iters,
         write_out=config.write_out,
         log_samples=True if config.predict_only else config.log_samples,
```
```diff
@@ -325,7 +331,9 @@ def simple_evaluate(
     if lm.rank == 0:
         if isinstance(config.model, str):
             model_name = config.model
-        elif hasattr(config.model, "config") and hasattr(config.model.config, "_name_or_path"):
+        elif hasattr(config.model, "config") and hasattr(
+            config.model.config, "_name_or_path"
+        ):
             model_name = config.model.config._name_or_path
         else:
             model_name = type(config.model).__name__
```
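This fallback chain derives a printable model name from whatever `config.model` holds: a string, an HF-style model object exposing `config._name_or_path`, or any other object. The same chain as a standalone function:

```python
def resolve_model_name(model) -> str:
    # Three-way fallback, as in the hunk above.
    if isinstance(model, str):
        return model
    if hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
        return model.config._name_or_path
    return type(model).__name__


print(resolve_model_name("hf"))  # hf
```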