Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
f264f2e2
Commit
f264f2e2
authored
Jul 22, 2025
by
Baber
Browse files
type hints
parent
230352ce
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
111 additions
and
37 deletions
+111
-37
.pre-commit-config.yaml
.pre-commit-config.yaml
+1
-1
lm_eval/api/task.py
lm_eval/api/task.py
+80
-7
lm_eval/config/metric.py
lm_eval/config/metric.py
+3
-3
lm_eval/config/task.py
lm_eval/config/task.py
+2
-8
lm_eval/utils.py
lm_eval/utils.py
+25
-18
No files found.
.pre-commit-config.yaml
View file @
f264f2e2
...
@@ -33,7 +33,7 @@ repos:
...
@@ -33,7 +33,7 @@ repos:
hooks
:
hooks
:
# Run the linter.
# Run the linter.
-
id
:
ruff-check
-
id
:
ruff-check
args
:
[
--fix
]
args
:
[
--fix
,
--unsafe-fixes
]
# Run the formatter.
# Run the formatter.
-
id
:
ruff-format
-
id
:
ruff-format
-
repo
:
https://github.com/codespell-project/codespell
-
repo
:
https://github.com/codespell-project/codespell
...
...
lm_eval/api/task.py
View file @
f264f2e2
...
@@ -11,6 +11,7 @@ from typing import (
...
@@ -11,6 +11,7 @@ from typing import (
TYPE_CHECKING
,
TYPE_CHECKING
,
Any
,
Any
,
Literal
,
Literal
,
overload
,
)
)
import
datasets
import
datasets
...
@@ -192,7 +193,7 @@ class Task(abc.ABC):
...
@@ -192,7 +193,7 @@ class Task(abc.ABC):
elif
self
.
has_validation_docs
():
elif
self
.
has_validation_docs
():
return
self
.
validation_docs
()
return
self
.
validation_docs
()
else
:
else
:
if
self
.
config
.
get
(
"
num_fewshot
"
,
0
)
>
0
:
if
self
.
config
.
num_fewshot
and
self
.
config
.
num_fewshot
>
0
:
eval_logger
.
warning
(
eval_logger
.
warning
(
f
"[Task:
{
self
.
config
.
task
}
] has_training_docs and has_validation_docs are False"
f
"[Task:
{
self
.
config
.
task
}
] has_training_docs and has_validation_docs are False"
", using test_docs as fewshot_docs but this is not recommended."
", using test_docs as fewshot_docs but this is not recommended."
...
@@ -331,7 +332,7 @@ class Task(abc.ABC):
...
@@ -331,7 +332,7 @@ class Task(abc.ABC):
inst
=
self
.
construct_requests
(
inst
=
self
.
construct_requests
(
doc
=
doc
,
doc
=
doc
,
ctx
=
fewshot_ctx
,
ctx
=
fewshot_ctx
,
metadata
=
(
self
.
config
[
"
task
"
]
,
doc_id
,
self
.
config
.
repeats
),
metadata
=
(
self
.
config
.
task
,
doc_id
,
self
.
config
.
repeats
),
apply_chat_template
=
apply_chat_template
,
apply_chat_template
=
apply_chat_template
,
chat_template
=
chat_template
,
chat_template
=
chat_template
,
)
)
...
@@ -990,9 +991,21 @@ class ConfigurableTask(Task):
...
@@ -990,9 +991,21 @@ class ConfigurableTask(Task):
"""
"""
return
doc
return
doc
@
overload
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
None
=
None
)
->
str
|
int
:
...
@
overload
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
int
)
->
int
:
...
@
overload
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
str
)
->
str
:
...
@
overload
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
Callable
[...,
str
])
->
str
:
...
def
doc_to_text
(
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
int
|
str
|
Callable
[...,
str
]
|
None
=
None
self
,
doc
:
dict
,
doc_to_text
:
int
|
str
|
Callable
[...,
str
]
|
None
=
None
)
->
str
:
)
->
str
|
int
:
# if self.prompt is not None:
# if self.prompt is not None:
# doc_to_text = self.prompt
# doc_to_text = self.prompt
doc_to_text
=
doc_to_text
or
self
.
config
.
doc_to_text
doc_to_text
=
doc_to_text
or
self
.
config
.
doc_to_text
...
@@ -1025,6 +1038,25 @@ class ConfigurableTask(Task):
...
@@ -1025,6 +1038,25 @@ class ConfigurableTask(Task):
print
(
type
(
doc_to_text
))
print
(
type
(
doc_to_text
))
raise
TypeError
raise
TypeError
@
overload
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
:
None
=
None
)
->
int
|
str
|
list
[
int
]:
...
@
overload
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
:
int
)
->
int
:
...
@
overload
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
:
str
)
->
int
|
str
|
list
[
int
]:
...
@
overload
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
:
list
)
->
list
[
int
]:
...
@
overload
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
:
Callable
[...,
int
|
str
|
list
[
int
]]
)
->
int
|
str
|
list
[
int
]:
...
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
=
None
)
->
int
|
str
|
list
[
int
]:
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
=
None
)
->
int
|
str
|
list
[
int
]:
# if self.prompt is not None:
# if self.prompt is not None:
# doc_to_target = self.prompt
# doc_to_target = self.prompt
...
@@ -1071,6 +1103,23 @@ class ConfigurableTask(Task):
...
@@ -1071,6 +1103,23 @@ class ConfigurableTask(Task):
else
:
else
:
raise
TypeError
raise
TypeError
@
overload
def
doc_to_choice
(
self
,
doc
:
dict
,
doc_to_choice
:
None
=
None
)
->
list
[
str
]:
...
@
overload
def
doc_to_choice
(
self
,
doc
:
dict
,
doc_to_choice
:
str
)
->
list
[
str
]:
...
@
overload
def
doc_to_choice
(
self
,
doc
:
dict
,
doc_to_choice
:
list
)
->
list
[
str
]:
...
@
overload
def
doc_to_choice
(
self
,
doc
:
dict
,
doc_to_choice
:
dict
)
->
list
[
str
]:
...
@
overload
def
doc_to_choice
(
self
,
doc
:
dict
,
doc_to_choice
:
Callable
[...,
list
[
str
]]
)
->
list
[
str
]:
...
def
doc_to_choice
(
def
doc_to_choice
(
self
,
self
,
doc
:
dict
,
doc
:
dict
,
...
@@ -1102,6 +1151,18 @@ class ConfigurableTask(Task):
...
@@ -1102,6 +1151,18 @@ class ConfigurableTask(Task):
else
:
else
:
raise
TypeError
raise
TypeError
@
overload
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
:
None
=
None
)
->
None
:
...
@
overload
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
:
list
)
->
list
:
...
@
overload
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
:
str
)
->
int
|
str
|
None
:
...
@
overload
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
:
Callable
[...,
Any
])
->
Any
:
...
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
=
None
)
->
int
|
str
|
list
|
None
:
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
=
None
)
->
int
|
str
|
list
|
None
:
if
doc_to_image
is
not
None
:
if
doc_to_image
is
not
None
:
doc_to_image
=
doc_to_image
doc_to_image
=
doc_to_image
...
@@ -1125,6 +1186,18 @@ class ConfigurableTask(Task):
...
@@ -1125,6 +1186,18 @@ class ConfigurableTask(Task):
else
:
else
:
return
None
return
None
@
overload
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
:
None
=
None
)
->
None
:
...
@
overload
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
:
list
)
->
list
:
...
@
overload
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
:
str
)
->
int
|
str
|
None
:
...
@
overload
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
:
Callable
[...,
Any
])
->
Any
:
...
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
=
None
)
->
int
|
str
|
list
|
None
:
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
=
None
)
->
int
|
str
|
list
|
None
:
if
doc_to_audio
is
not
None
:
if
doc_to_audio
is
not
None
:
doc_to_audio
=
doc_to_audio
doc_to_audio
=
doc_to_audio
...
@@ -1369,15 +1442,15 @@ class ConfigurableTask(Task):
...
@@ -1369,15 +1442,15 @@ class ConfigurableTask(Task):
elif
self
.
OUTPUT_TYPE
==
"generate_until"
:
elif
self
.
OUTPUT_TYPE
==
"generate_until"
:
gold
=
self
.
doc_to_target
(
doc
)
gold
=
self
.
doc_to_target
(
doc
)
result
=
results
[
0
]
result
=
results
[
0
]
for
metric
in
self
.
_metric_
fn_
list
:
for
metric
in
self
.
config
.
_metric_list
:
try
:
try
:
result_score
=
self
.
_
metric
_
fn
_list
[
metric
]
(
result_score
=
metric
.
fn
(
references
=
[
gold
]
if
not
isinstance
(
gold
,
list
)
else
gold
,
references
=
[
gold
]
if
not
isinstance
(
gold
,
list
)
else
gold
,
predictions
=
[
result
],
predictions
=
[
result
],
**
self
.
_
metric
_fn_
kwargs
[
metric
]
,
**
metric
.
kwargs
,
)
)
except
TypeError
:
# needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
except
TypeError
:
# needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score
=
self
.
_
metric
_
fn
_list
[
metric
]
([
gold
,
result
])
result_score
=
metric
.
fn
([
gold
,
result
])
if
isinstance
(
result_score
,
dict
):
if
isinstance
(
result_score
,
dict
):
# TODO: this handles the case where HF evaluate returns a dict.
# TODO: this handles the case where HF evaluate returns a dict.
# This allows for multiple metrics to be returned from the same function
# This allows for multiple metrics to be returned from the same function
...
...
lm_eval/config/metric.py
View file @
f264f2e2
from
__future__
import
annotations
from
__future__
import
annotations
from
collections.abc
import
Callable
,
Mapping
from
collections.abc
import
Callable
,
Mapping
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Any
from
typing
import
Any
...
@@ -11,8 +11,8 @@ class MetricConfig:
...
@@ -11,8 +11,8 @@ class MetricConfig:
"""Encapsulates information about a single metric."""
"""Encapsulates information about a single metric."""
name
:
str
name
:
str
fn
:
Callable
|
None
=
None
fn
:
Callable
kwargs
:
Mapping
[
str
,
Any
]
|
None
=
None
kwargs
:
Mapping
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
aggregation_fn
:
Callable
|
None
=
None
aggregation_fn
:
Callable
|
None
=
None
higher_is_better
:
bool
=
True
higher_is_better
:
bool
=
True
hf_evaluate
:
bool
=
False
hf_evaluate
:
bool
=
False
...
...
lm_eval/config/task.py
View file @
f264f2e2
...
@@ -44,7 +44,7 @@ class FewshotConfig:
...
@@ -44,7 +44,7 @@ class FewshotConfig:
num_fewshot
:
Callable
[[],
int
]
num_fewshot
:
Callable
[[],
int
]
split
:
str
|
None
=
None
split
:
str
|
None
=
None
sampler
:
str
|
Callable
=
"default"
sampler
:
str
|
Callable
=
"default"
samples
:
Callable
[[],
list
[
dict
]]
|
list
[
dict
]
|
None
=
None
samples
:
Callable
[[],
Iterable
[
dict
]]
|
Iterable
[
dict
]
|
None
=
None
process_docs
:
Callable
[[
list
[
dict
[
str
,
Any
]]],
Iterable
[
dict
[
str
,
Any
]]]
|
None
=
(
process_docs
:
Callable
[[
list
[
dict
[
str
,
Any
]]],
Iterable
[
dict
[
str
,
Any
]]]
|
None
=
(
None
None
)
)
...
@@ -71,7 +71,7 @@ class FewshotConfig:
...
@@ -71,7 +71,7 @@ class FewshotConfig:
def
_get_raw_docs
(
def
_get_raw_docs
(
self
,
dataset
self
,
dataset
)
->
list
[
dict
]
|
Callable
[[],
Iterable
[
dict
]]
|
None
:
)
->
list
[
dict
]
|
Callable
[[],
Iterable
[
dict
[
str
,
Any
]
]]
|
None
:
"""Get raw documents from configured source."""
"""Get raw documents from configured source."""
if
self
.
split
is
not
None
:
if
self
.
split
is
not
None
:
return
dataset
[
self
.
split
]
return
dataset
[
self
.
split
]
...
@@ -425,12 +425,6 @@ class TaskConfig:
...
@@ -425,12 +425,6 @@ class TaskConfig:
# Create and return TaskConfig instance
# Create and return TaskConfig instance
return
cls
(
**
config_dict
)
return
cls
(
**
config_dict
)
def
__getitem__
(
self
,
item
):
return
getattr
(
self
,
item
)
def
__setitem__
(
self
,
item
,
value
):
return
setattr
(
self
,
item
,
value
)
def
to_dict
(
self
,
keep_callable
:
bool
=
False
)
->
dict
:
def
to_dict
(
self
,
keep_callable
:
bool
=
False
)
->
dict
:
def
_ser
(
x
):
def
_ser
(
x
):
if
isinstance
(
x
,
dict
):
if
isinstance
(
x
,
dict
):
...
...
lm_eval/utils.py
View file @
f264f2e2
import
collections
import
collections
import
fnmatch
import
fnmatch
import
functools
import
hashlib
import
hashlib
import
importlib.util
import
importlib.util
import
inspect
import
inspect
...
@@ -8,10 +7,12 @@ import json
...
@@ -8,10 +7,12 @@ import json
import
logging
import
logging
import
os
import
os
import
re
import
re
from
collections.abc
import
Generator
from
dataclasses
import
asdict
,
is_dataclass
from
dataclasses
import
asdict
,
is_dataclass
from
functools
import
lru_cache
,
partial
,
wraps
from
itertools
import
islice
from
itertools
import
islice
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Generator
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Optional
import
numpy
as
np
import
numpy
as
np
import
yaml
import
yaml
...
@@ -91,7 +92,7 @@ def escaped_split(text, sep_char, maxsplit=-1):
...
@@ -91,7 +92,7 @@ def escaped_split(text, sep_char, maxsplit=-1):
return
text
return
text
maxsplit
=
max
(
0
,
maxsplit
)
maxsplit
=
max
(
0
,
maxsplit
)
return
re
.
split
(
r
"(?<!\\)"
+
sep_char
,
text
,
maxsplit
)
return
re
.
split
(
r
"(?<!\\)"
+
sep_char
,
text
,
maxsplit
=
maxsplit
)
def
handle_arg_string
(
arg
):
def
handle_arg_string
(
arg
):
...
@@ -108,7 +109,7 @@ def handle_arg_string(arg):
...
@@ -108,7 +109,7 @@ def handle_arg_string(arg):
def
handle_non_serializable
(
o
):
def
handle_non_serializable
(
o
):
if
isinstance
(
o
,
np
.
int
64
)
or
isinstance
(
o
,
np
.
int32
):
if
isinstance
(
o
,
np
.
int
eger
):
return
int
(
o
)
return
int
(
o
)
elif
isinstance
(
o
,
set
):
elif
isinstance
(
o
,
set
):
return
list
(
o
)
return
list
(
o
)
...
@@ -218,21 +219,21 @@ def sanitize_task_name(task_name: str) -> str:
...
@@ -218,21 +219,21 @@ def sanitize_task_name(task_name: str) -> str:
return
re
.
sub
(
r
"\W"
,
"_"
,
task_name
)
return
re
.
sub
(
r
"\W"
,
"_"
,
task_name
)
def
get_latest_filename
(
filenames
:
L
ist
[
str
])
->
str
:
def
get_latest_filename
(
filenames
:
l
ist
[
str
])
->
str
:
"""
"""
Given a list of filenames, returns the filename with the latest datetime.
Given a list of filenames, returns the filename with the latest datetime.
"""
"""
return
max
(
filenames
,
key
=
lambda
f
:
get_file_datetime
(
f
))
return
max
(
filenames
,
key
=
lambda
f
:
get_file_datetime
(
f
))
def
get_results_filenames
(
filenames
:
L
ist
[
str
])
->
L
ist
[
str
]:
def
get_results_filenames
(
filenames
:
l
ist
[
str
])
->
l
ist
[
str
]:
"""
"""
Extracts filenames that correspond to aggregated results.
Extracts filenames that correspond to aggregated results.
"""
"""
return
[
f
for
f
in
filenames
if
"/results_"
in
f
and
".json"
in
f
]
return
[
f
for
f
in
filenames
if
"/results_"
in
f
and
".json"
in
f
]
def
get_sample_results_filenames
(
filenames
:
L
ist
[
str
])
->
L
ist
[
str
]:
def
get_sample_results_filenames
(
filenames
:
l
ist
[
str
])
->
l
ist
[
str
]:
"""
"""
Extracts filenames that correspond to sample results.
Extracts filenames that correspond to sample results.
"""
"""
...
@@ -240,8 +241,8 @@ def get_sample_results_filenames(filenames: List[str]) -> List[str]:
...
@@ -240,8 +241,8 @@ def get_sample_results_filenames(filenames: List[str]) -> List[str]:
def
get_rolling_token_windows
(
def
get_rolling_token_windows
(
token_list
:
L
ist
[
int
],
prefix_token
:
int
,
max_seq_len
:
int
,
context_len
:
int
token_list
:
l
ist
[
int
],
prefix_token
:
int
,
max_seq_len
:
int
,
context_len
:
int
)
->
Generator
[
T
uple
[
L
ist
[
int
],
L
ist
[
int
]],
None
,
None
]:
)
->
Generator
[
t
uple
[
l
ist
[
int
],
l
ist
[
int
]],
None
,
None
]:
"""
"""
- context_len allows for a rolling window context, allowing each prediction window to potentially
- context_len allows for a rolling window context, allowing each prediction window to potentially
condition on some context
condition on some context
...
@@ -283,8 +284,8 @@ def get_rolling_token_windows(
...
@@ -283,8 +284,8 @@ def get_rolling_token_windows(
def
make_disjoint_window
(
def
make_disjoint_window
(
pair
:
T
uple
[
L
ist
[
int
],
L
ist
[
int
]],
pair
:
t
uple
[
l
ist
[
int
],
l
ist
[
int
]],
)
->
T
uple
[
L
ist
[
int
],
L
ist
[
int
]]:
)
->
t
uple
[
l
ist
[
int
],
l
ist
[
int
]]:
"""Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
"""Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
a
,
b
=
pair
a
,
b
=
pair
return
a
[:
len
(
a
)
-
(
len
(
b
)
-
1
)],
b
return
a
[:
len
(
a
)
-
(
len
(
b
)
-
1
)],
b
...
@@ -303,7 +304,7 @@ class EnhancedJSONEncoder(json.JSONEncoder):
...
@@ -303,7 +304,7 @@ class EnhancedJSONEncoder(json.JSONEncoder):
class
Reorderer
:
class
Reorderer
:
def
__init__
(
self
,
arr
:
L
ist
[
Any
],
fn
:
Callable
)
->
None
:
def
__init__
(
self
,
arr
:
l
ist
[
Any
],
fn
:
Callable
)
->
None
:
"""Reorder an array according to some function
"""Reorder an array according to some function
Args:
Args:
...
@@ -406,11 +407,11 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
...
@@ -406,11 +407,11 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
# TODO: fix
# TODO: fix
hib
=
"↑"
hib
=
"↑"
v
=
"%.4f"
%
v
if
isinstance
(
v
,
float
)
else
v
v
=
f
"
{
v
:.
4
f
}
"
if
isinstance
(
v
,
float
)
else
v
if
m
+
"_stderr"
+
","
+
f
in
dic
:
if
m
+
"_stderr"
+
","
+
f
in
dic
:
se
=
dic
[
m
+
"_stderr"
+
","
+
f
]
se
=
dic
[
m
+
"_stderr"
+
","
+
f
]
se
=
" N/A"
if
se
==
"N/A"
else
"%.4f"
%
se
se
=
" N/A"
if
se
==
"N/A"
else
f
"
{
se
:.
4
f
}
"
values
.
append
([
k
,
version
,
f
,
n
,
m
,
hib
,
v
,
"±"
,
se
])
values
.
append
([
k
,
version
,
f
,
n
,
m
,
hib
,
v
,
"±"
,
se
])
else
:
else
:
values
.
append
([
k
,
version
,
f
,
n
,
m
,
hib
,
v
,
""
,
""
])
values
.
append
([
k
,
version
,
f
,
n
,
m
,
hib
,
v
,
""
,
""
])
...
@@ -431,7 +432,8 @@ def positional_deprecated(fn):
...
@@ -431,7 +432,8 @@ def positional_deprecated(fn):
wrapped function, `fn`.
wrapped function, `fn`.
"""
"""
@
functools
.
wraps
(
fn
)
wraps
(
fn
)
def
_wrapper
(
*
args
,
**
kwargs
):
def
_wrapper
(
*
args
,
**
kwargs
):
if
len
(
args
)
!=
1
if
inspect
.
ismethod
(
fn
)
else
0
:
if
len
(
args
)
!=
1
if
inspect
.
ismethod
(
fn
)
else
0
:
print
(
print
(
...
@@ -477,7 +479,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
...
@@ -477,7 +479,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
if
yaml_path
is
None
:
if
yaml_path
is
None
:
raise
ValueError
(
"yaml_path must be provided if mode is 'full'."
)
raise
ValueError
(
"yaml_path must be provided if mode is 'full'."
)
# Attach yaml_path to the import function so that it can be used later
# Attach yaml_path to the import function so that it can be used later
constructor_fn
=
functools
.
partial
(
import_function
,
yaml_path
=
Path
(
yaml_path
))
constructor_fn
=
partial
(
import_function
,
yaml_path
=
Path
(
yaml_path
))
loader
=
yaml
.
CLoader
if
yaml
.
__with_libyaml__
else
yaml
.
FullLoader
loader
=
yaml
.
CLoader
if
yaml
.
__with_libyaml__
else
yaml
.
FullLoader
# Add the import_function constructor to the YAML loader
# Add the import_function constructor to the YAML loader
...
@@ -526,13 +528,18 @@ def regex_replace(string, pattern, repl, count: int = 0):
...
@@ -526,13 +528,18 @@ def regex_replace(string, pattern, repl, count: int = 0):
env
=
Environment
(
env
=
Environment
(
loader
=
BaseLoader
,
undefined
=
StrictUndefined
,
keep_trailing_newline
=
True
loader
=
BaseLoader
()
,
undefined
=
StrictUndefined
,
keep_trailing_newline
=
True
)
)
env
.
filters
[
"regex_replace"
]
=
regex_replace
env
.
filters
[
"regex_replace"
]
=
regex_replace
@
lru_cache
(
maxsize
=
128
)
def
_compile
(
raw
:
str
):
return
env
.
from_string
(
raw
)
def
apply_template
(
template
:
str
,
doc
:
dict
)
->
str
:
def
apply_template
(
template
:
str
,
doc
:
dict
)
->
str
:
rtemplate
=
env
.
from_string
(
template
)
rtemplate
=
_compile
(
template
)
return
rtemplate
.
render
(
**
doc
)
return
rtemplate
.
render
(
**
doc
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment