gaoqiong / lm-evaluation-harness · Commits

Commit 5b8a7506, authored Jun 23, 2025 by Baber
Parent: ba1d4483

    remove other schemas. work on metrics

Showing 8 changed files with 118 additions and 115 deletions (+118 -115)
lm_eval/api/filter.py          +7   -8
lm_eval/api/instance.py        +8   -7
lm_eval/api/model.py           +7   -7
lm_eval/api/schemas.py         +62  -57
lm_eval/api/task.py            +20  -13
lm_eval/evaluator.py           +1   -3
lm_eval/filters/selection.py   +1   -1
lm_eval/models/huggingface.py  +12  -19
lm_eval/api/filter.py

@@ -3,7 +3,6 @@ from dataclasses import dataclass
 from typing import Callable, Iterable, List, Union

 from lm_eval.api.instance import Instance
-from lm_eval.api.schemas import GenerateOutput


 class Filter(ABC):
@@ -47,13 +46,13 @@ class FilterEnsemble:
         resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
         # TODO: add backward
         # unwrap responses from GenerateOutput as the filters expect strings
-        resps = tuple(
-            [
-                item.text if isinstance(item, GenerateOutput) else str(item)
-                for item in sublist
-            ]
-            for sublist in resps
-        )
+        # resps = tuple(
+        #     [
+        #         item.text if isinstance(item, GenerateOutput) else item
+        #         for item in sublist
+        #     ]
+        #     for sublist in resps
+        # )

         for f in self.filters:
             # apply filters in sequence
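With the unwrapping step commented out, FilterEnsemble.apply now hands the raw per-doc response lists straight to each filter, which are expected to operate on plain strings. A minimal sketch of that flow, assuming the harness at this commit is importable; LowercaseFilter is a hypothetical filter, not part of the repository:

    from lm_eval.api.filter import Filter


    class LowercaseFilter(Filter):
        def apply(self, resps, docs):
            # each entry of `resps` is the list of string responses for one doc
            return [[r.lower() for r in sublist] for sublist in resps]


    resps = [["The Answer IS 42"], ["YES"]]
    docs = [{}, {}]
    print(LowercaseFilter().apply(resps, docs))
    # [['the answer is 42'], ['yes']]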
lm_eval/api/instance.py

 from dataclasses import dataclass, field
-from typing import Generic, Literal, Optional, Tuple, TypeVar, Union
+from typing import Literal, Optional, Tuple

-from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput
+# from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput

 OutputType = Literal[
     "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
 ]

-T = TypeVar("T", LoglikelihoodInput, GenerateInput)
+# T = TypeVar("T", LoglikelihoodInput, GenerateInput)


 @dataclass
-class Instance(Generic[T]):
+class Instance:
     request_type: OutputType
     doc: dict
-    arguments: T
+    arguments: tuple
     idx: int
     metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
         default_factory=lambda: (None, None, None)
     )
-    resps: list[Union[GenerateInput, LoglikelihoodInput]] = field(default_factory=list)
+    resps: list = field(default_factory=list)
     filtered_resps: dict = field(default_factory=dict)

     # initialized after init
@@ -33,7 +34,7 @@ class Instance(Generic[T]):
         self.task_name, self.doc_id, self.repeats = self.metadata

     @property
-    def args(self) -> T:
+    def args(self):
         """
         Returns (string,) where `string` is the string to calculate loglikelihood over
         """
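After this change `arguments` is back to a plain tuple. A small sketch of constructing a request, assuming the module at this commit is importable; the task name and doc values here are made up:

    from lm_eval.api.instance import Instance

    inst = Instance(
        request_type="loglikelihood",
        doc={"question": "2+2=?", "answer": "4"},
        arguments=("2+2=", " 4"),  # bare (context, continuation) tuple
        idx=0,
        metadata=("my_task", 0, 1),  # unpacked into task_name, doc_id, repeats
    )
    context, continuation = inst.arguments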
lm_eval/api/model.py

@@ -9,10 +9,12 @@ from tqdm import tqdm
 from lm_eval import utils
 from lm_eval.api.instance import Instance
-from lm_eval.api.schemas import (
-    LoglikelihoodInput,
-    LoglikelihoodOutput,
-)
+# from lm_eval.api.schemas import (
+#     LoglikelihoodInput,
+#     LoglikelihoodOutput,
+# )

 if TYPE_CHECKING:
@@ -381,9 +383,7 @@ class TemplateLM(LM):
         self, requests: list["Instance"], disable_tqdm: bool = False
     ) -> list[tuple[float, bool]]:
         new_reqs = []
-        for context, continuation in (
-            (req.args.context, req.args.continuation) for req in requests
-        ):
+        for context, continuation in [req.args for req in requests]:
             if context == "":
                 # BOS or EOS as context
                 context_enc, continuation_enc = (
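The loop now unpacks (context, continuation) pairs directly rather than reading .context/.continuation off a LoglikelihoodInput. A standalone sketch of that shape, using toy stand-ins rather than the real classes:

    # Toy stand-in: req.args is a (context, continuation) tuple after this commit.
    class Req:
        def __init__(self, args):
            self.args = args


    requests = [Req(("Q: 2+2=", " 4")), Req(("", " Once upon a time"))]

    for context, continuation in [req.args for req in requests]:
        if context == "":
            print("empty context: fall back to BOS/EOS handling")
        else:
            print(f"encode {context!r} + {continuation!r}")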
lm_eval/api/schemas.py

@@ -2,61 +2,60 @@ from dataclasses import dataclass
 from typing import Optional


-@dataclass
-class GenerateInput:
-    """
-    Inputs for the generate function.
-    """
-
-    prompt: str
-    gen_kwargs: dict
-    multimodal_arg: Optional[dict] = None
-
-    def __iter__(self):
-        return (
-            iter((self.prompt, self.gen_kwargs))
-            if not self.multimodal_arg
-            else iter((self.prompt, self.gen_kwargs, self.multimodal_arg))
-        )
-
-    def __getitem__(self, item: int):
-        return [self.prompt, self.gen_kwargs][item]
-
-
-@dataclass
-class GenerateOutput:
-    """
-    Outputs for the generate function.
-    """
-
-    text: str
-    metadata: dict = None
-
-
-@dataclass
-class LoglikelihoodInput:
-    """
-    Inputs for the loglikelihood function.
-    """
-
-    context: str
-    continuation: Optional[str] = None
-
-
-@dataclass
-class LoglikelihoodOutput:
-    """
-    Outputs for the loglikelihood function.
-    """
-
-    loglikelihood: float
-    is_greedy: Optional[bool] = None
-    ctx_tokens: Optional[list[int]] = None
-    cont_tokens: Optional[list[int]] = None
-    metadata: Optional[dict] = None
-
-    def __iter__(self):
-        return iter((self.loglikelihood, self.is_greedy))
+# @dataclass
+# class GenerateInput:
+#     """
+#     Inputs for the generate function.
+#     """
+#
+#     prompt: str
+#     gen_kwargs: dict
+#     multimodal_arg: Optional[dict] = None
+#
+#     def __iter__(self):
+#         return (
+#             iter((self.prompt, self.gen_kwargs))
+#             if not self.multimodal_arg
+#             else iter((self.prompt, self.gen_kwargs, self.multimodal_arg))
+#         )
+#
+#     def __getitem__(self, item: int):
+#         return [self.prompt, self.gen_kwargs][item]
+#
+#
+# @dataclass
+# class GenerateOutput:
+#     """
+#     Outputs for the generate function.
+#     """
+#
+#     text: str
+#     metadata: dict = None
+#
+#
+# @dataclass
+# class LoglikelihoodInput:
+#     """
+#     Inputs for the loglikelihood function.
+#     """
+#
+#     context: str
+#     continuation: Optional[str] = None
+#
+#
+# class LoglikelihoodOutput(NamedTuple):
+#     """
+#     Outputs for the loglikelihood function.
+#     """
+#
+#     loglikelihood: float
+#     is_greedy: Optional[bool] = None
+#     ctx_tokens: Optional[list[int]] = None
+#     cont_tokens: Optional[list[int]] = None
+#     metadata: Optional[dict] = None
+#
+#     def __iter__(self):
+#         return iter((self.loglikelihood, self.is_greedy))


 @dataclass
@@ -66,7 +65,7 @@ class MetricResult:
     """
     doc_id: str | int | None
-    scores: list[dict[str, float]] | None
+    scores: list[dict[str, float]] | dict
     filter_key: str = None
     metric_name: str = None
     metadata: Optional[dict] = None
@@ -76,6 +75,8 @@ class MetricResult:
             return iter([])
         # Group values by metric key
+        if not isinstance(self.scores, list):
+            self.scores = [self.scores]
         grouped = {}
         for score_dict in self.scores:
             for key, value in score_dict.items():
@@ -99,4 +100,8 @@ class MetricResult:
     def metric_keys(self) -> list[str]:
         if self.scores is None:
             return []
-        return list(self.scores[0].keys()) if self.scores else []
+        return (
+            list(self.scores[0].keys())
+            if isinstance(self.scores, list)
+            else list(self.scores.keys())
+        )
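MetricResult now tolerates either a single score dict (as produced by the new aggregated process_results path in task.py) or a list of per-response dicts. A sketch, assuming the schemas module at this commit is importable and keyword construction of the dataclass:

    from lm_eval.api.schemas import MetricResult

    single = MetricResult(doc_id=0, scores={"acc": 1.0}, filter_key="none")
    many = MetricResult(doc_id=1, scores=[{"acc": 0.0}, {"acc": 1.0}], filter_key="none")

    print(single.metric_keys())  # ['acc'] via the new dict branch
    print(many.metric_keys())    # ['acc'] via the list branch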
lm_eval/api/task.py

@@ -37,7 +37,7 @@ from lm_eval.api.registry import (
     get_metric_aggregation,
     is_higher_better,
 )
-from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput, MetricResult
+from lm_eval.api.schemas import MetricResult
 from lm_eval.caching.cache import load_from_cache, save_to_cache
 from lm_eval.filters import build_filter_ensemble
 from lm_eval.prompts import get_prompt
@@ -1531,7 +1531,8 @@ class ConfigurableTask(Task):
                 Instance(
                     request_type="loglikelihood",
                     doc=doc,
-                    arguments=LoglikelihoodInput(context=arg[0], continuation=arg[1]),
+                    arguments=arg,
+                    # arguments=LoglikelihoodInput(context=arg[0], continuation=arg[1]),
                     idx=i,
                     **kwargs,
                 )
@@ -1543,9 +1544,9 @@ class ConfigurableTask(Task):
         return Instance(
             request_type=self.OUTPUT_TYPE,
             doc=doc,
-            arguments=LoglikelihoodInput(*arguments)
-            if self.OUTPUT_TYPE in ["loglikelihood", "loglikelihood_rolling"]
-            else GenerateInput(*arguments),
+            arguments=arguments,
+            # if self.OUTPUT_TYPE in ["loglikelihood", "loglikelihood_rolling"]
+            # else GenerateInput(*arguments),
             idx=0,
             **kwargs,
         )
@@ -1819,15 +1820,21 @@ class ConfigurableTask(Task):
         for doc_id, doc in doc_iterator:
             # doc_id_true = indices[doc_id] if indices else doc_id
             requests = instances_by_doc_id[doc_id]
-            metrics = [
-                self.process_results(doc, response)
-                for req in requests
-                for response in (
-                    req.filtered_resps[filter_key]
-                    if isinstance(req.filtered_resps[filter_key], list)
-                    else [req.filtered_resps[filter_key]]
-                )
-            ]
+            if len(requests) > 1:
+                # if one doc has multiple instances then calculate metric together
+                metrics = self.process_results(
+                    doc, [req.filtered_resps[filter_key] for req in requests]
+                )
+            else:
+                metrics = [
+                    self.process_results(doc, response)
+                    for req in requests
+                    for response in (
+                        req.filtered_resps[filter_key]
+                        if isinstance(req.filtered_resps[filter_key], list)
+                        else [req.filtered_resps[filter_key]]
+                    )
+                ]
             all_metrics[filter_key].append(
                 MetricResult(scores=metrics, doc_id=doc_id, filter_key=filter_key)
             )
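The new branch is the "work on metrics" part: when one doc has several instances (e.g. one per answer choice), their filtered responses are scored together in a single process_results call. A standalone toy sketch of that path; FakeTask and the response values are made up, not the real Task API:

    from types import SimpleNamespace


    class FakeTask:
        def process_results(self, doc, responses):
            # e.g. multiple choice: argmax over per-choice loglikelihoods
            pred = max(range(len(responses)), key=lambda i: responses[i][0])
            return {"acc": float(pred == doc["gold"])}


    task, filter_key = FakeTask(), "none"
    doc = {"gold": 1}
    requests = [
        SimpleNamespace(filtered_resps={"none": (-4.2, False)}),
        SimpleNamespace(filtered_resps={"none": (-1.3, True)}),
    ]

    if len(requests) > 1:
        # if one doc has multiple instances then calculate metric together
        metrics = task.process_results(
            doc, [req.filtered_resps[filter_key] for req in requests]
        )
        print(metrics)  # {'acc': 1.0}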
lm_eval/evaluator.py

@@ -647,9 +647,7 @@ def evaluate(
                     ensure_ascii=False,
                 )
             ),
-            "prompt_hash": hash_string(
-                requests[0].arguments.prompt
-            ),
+            "prompt_hash": hash_string(requests[0].arguments[0]),
             "target_hash": hash_string(str(target)),
         }
         example.update(
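With arguments back to a plain tuple, the logged prompt hash is computed from its first element rather than a .prompt attribute. A sketch with a local sha256 stand-in for hash_string (the harness's own helper may differ):

    import hashlib


    def hash_string(s: str) -> str:  # local stand-in, not lm_eval's utility
        return hashlib.sha256(s.encode("utf-8")).hexdigest()


    arguments = ("Q: What is 2+2?\nA:", {"max_gen_toks": 16})  # (prompt, gen_kwargs)
    print(hash_string(arguments[0])[:12])  # hash the prompt string, arguments[0]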
lm_eval/filters/selection.py

@@ -20,7 +20,7 @@ class TakeFirstFilter(Filter):
         """
         Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
         """
-        return map(lambda r: r, resps)
+        return map(lambda r: r[0], resps)


 @register_filter("take_first_k")
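The one-line fix restores the documented behavior: the previous lambda returned each response list unchanged instead of taking its first element. A quick illustration:

    resps = [["answer A", "answer B"], ["yes", "no"]]
    print(list(map(lambda r: r[0], resps)))  # ['answer A', 'yes']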
lm_eval/models/huggingface.py

@@ -27,7 +27,6 @@ from lm_eval import utils
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import TemplateLM
 from lm_eval.api.registry import register_model
-from lm_eval.api.schemas import GenerateInput, GenerateOutput, LoglikelihoodOutput
 from lm_eval.models.utils import (
     Collator,
     clear_torch_cache,
@@ -966,7 +965,7 @@ class HFLM(TemplateLM):
     def loglikelihood_rolling(
         self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> list[LoglikelihoodOutput]:
+    ) -> List[float]:
         adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context
@@ -1026,7 +1025,7 @@ class HFLM(TemplateLM):
                 override_bs=len(batch_windows),
             )
             # Store results with their request indices
-            all_nlls.extend(zip(batch_indices, (x.loglikelihood for x in batch_nlls)))
+            all_nlls.extend(zip(batch_indices, batch_nlls))
         # Remove padding if necessary
         if (self.world_size > 1) and (pad_amnt > 0):
@@ -1039,8 +1038,8 @@ class HFLM(TemplateLM):
             # Get all nlls for this request
             request_nlls = all_nlls[current_idx : current_idx + window_count]
             # Sum up the nlls for this request (discarding is_greedy)
-            request_total = sum(nll for nll in request_nlls)
-            loglikelihoods.append(LoglikelihoodOutput(loglikelihood=request_total))
+            request_total = sum(nll[0] for _, nll in request_nlls)
+            loglikelihoods.append(request_total)
             current_idx += window_count

             string = requests[len(loglikelihoods) - 1].args[0]
@@ -1072,7 +1071,7 @@ class HFLM(TemplateLM):
         requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
         disable_tqdm: bool = False,
         override_bs: int = None,
-    ) -> List[LoglikelihoodOutput]:
+    ) -> List[Tuple[float, bool]]:
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
@@ -1287,13 +1286,7 @@ class HFLM(TemplateLM):
             # Answer: (log prob, is-exact-match)
             answer = (float(logits.sum()), bool(max_equal))
-            res.append(
-                LoglikelihoodOutput(
-                    *answer,
-                    ctx_tokens=ctx_tokens,
-                    cont_tokens=cont_toks.tolist(),
-                )
-            )
+            res.append(answer)

             if request_str is not None:
                 # special case: loglikelihood_rolling produces a number of loglikelihood requests
@@ -1309,8 +1302,8 @@ class HFLM(TemplateLM):
         return re_ord.get_original(res)

     def generate_until(
-        self, requests: List[Instance[GenerateInput]], disable_tqdm: bool = False
-    ) -> List[GenerateOutput]:
+        self, requests: List[Instance], disable_tqdm: bool = False
+    ) -> List[str]:
         res = []

         def _collate(req: Tuple[str, dict]):
@@ -1321,8 +1314,8 @@ class HFLM(TemplateLM):
             # padded context length. this is useful to simplify the batching logic and more importantly to make
             # automatic adaptive batches much much easier to implement
             # - any OOMs will happen right away rather than near the end
-            toks = self.tok_encode(req.prompt)
-            return -len(toks), req.prompt
+            toks = self.tok_encode(req[0])
+            return -len(toks), req[0]

         pbar = tqdm(
             total=len(requests),
@@ -1358,7 +1351,7 @@ class HFLM(TemplateLM):
             [reg.args for reg in requests],
             sort_fn=_collate,
             group_by="gen_kwargs",
-            group_fn=lambda x: x.gen_kwargs,
+            group_fn=lambda x: x[1],
         )
         chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
         eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
@@ -1427,7 +1420,7 @@ class HFLM(TemplateLM):
                 # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                 s = s.split(term)[0]
-                res.append(GenerateOutput(text=s))
+                res.append(s)
                 self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                 pbar.update(1)
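Since _loglikelihood_tokens returns plain (logprob, is_greedy) tuples again, the rolling-loglikelihood bookkeeping sums nll[0] from each (index, (logprob, is_greedy)) pair. A standalone sketch of that arithmetic with made-up values:

    # (index, (logprob, is_greedy)) pairs, as produced by the new
    # all_nlls.extend(zip(batch_indices, batch_nlls)) bookkeeping
    all_nlls = [(0, (-1.5, True)), (1, (-0.75, False)), (2, (-2.0, True))]

    current_idx, window_count = 0, 3  # one request spanning three windows
    request_nlls = all_nlls[current_idx : current_idx + window_count]
    request_total = sum(nll[0] for _, nll in request_nlls)  # is_greedy discarded
    print(request_total)  # -4.25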