gaoqiong / lm-evaluation-harness · Commits · 6d63c2ce

Commit 6d63c2ce, authored Jul 23, 2025 by Baber
types
Parent: 0087929e
Showing 2 changed files with 180 additions and 181 deletions:

  lm_eval/models/huggingface.py      +138  -141
  lm_eval/models/vllm_causallms.py    +42   -40
lm_eval/models/huggingface.py (view file @ 6d63c2ce)
+from __future__ import annotations
+
 import copy
 import logging
 import os
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal

 import jinja2
 import torch
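
Note: the new `from __future__ import annotations` line (PEP 563) makes every annotation a lazily evaluated string, which is what lets the PEP 604 `X | Y` unions and builtin generics used throughout this commit appear in signatures even on interpreters older than 3.10. A minimal illustrative sketch (not part of this commit; the function and names are made up):

from __future__ import annotations  # annotations become strings, evaluated lazily


# With postponed evaluation, "str | None" is never executed at import time,
# so this definition also imports cleanly on Python 3.8/3.9.
def greet(name: str | None = None) -> str:
    return f"hello {name or 'world'}"


if __name__ == "__main__":
    print(greet())                 # hello world
    print(greet.__annotations__)   # {'name': 'str | None', 'return': 'str'}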
@@ -40,7 +42,7 @@ from lm_eval.models.utils import (

 if TYPE_CHECKING:
-    from transformers.quantizers import AutoQuantizationConfig
+    from transformers.quantizers.auto import AutoQuantizationConfig

 eval_logger = logging.getLogger(__name__)

@@ -59,46 +61,43 @@ class HFLM(TemplateLM):
     def __init__(
         self,
-        pretrained: Union[str, transformers.PreTrainedModel],
+        pretrained: str | transformers.PreTrainedModel,
         backend: Literal["default", "causal", "seq2seq"] = "default",
         # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
-        revision: Optional[str] = "main",
+        revision: str | None = "main",
         subfolder: str = "",
-        tokenizer: Optional[
-            Union[
-                str,
-                transformers.PreTrainedTokenizer,
-                transformers.PreTrainedTokenizerFast,
-            ]
-        ] = None,
-        truncation: Optional[bool] = False,
+        tokenizer: str
+        | transformers.PreTrainedTokenizer
+        | transformers.PreTrainedTokenizerFast
+        | None = None,
+        truncation: bool | None = False,
         logits_cache: bool = True,
-        max_length: Optional[int] = None,
-        device: Optional[str] = "cuda",
-        dtype: Optional[Union[str, torch.dtype]] = "auto",
-        softmax_dtype: Optional[Union[str, torch.dtype]] = None,
-        mixed_precision_dtype: Optional[Union[str, torch.dtype]] = None,
-        batch_size: Optional[Union[int, str]] = 1,
-        max_batch_size: Optional[int] = 64,
-        trust_remote_code: Optional[bool] = False,
-        use_fast_tokenizer: Optional[bool] = True,
-        add_bos_token: Optional[bool] = False,
-        prefix_token_id: Optional[int] = None,
+        max_length: int | None = None,
+        device: str | None = "cuda",
+        dtype: str | torch.dtype | None = "auto",
+        softmax_dtype: str | torch.dtype | None = None,
+        mixed_precision_dtype: str | torch.dtype | None = None,
+        batch_size: int | str | None = 1,
+        max_batch_size: int | None = 64,
+        trust_remote_code: bool | None = False,
+        use_fast_tokenizer: bool | None = True,
+        add_bos_token: bool | None = False,
+        prefix_token_id: int | None = None,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
-        parallelize: Optional[bool] = False,
-        max_memory_per_gpu: Optional[Union[int, str]] = None,
-        max_cpu_memory: Optional[Union[int, str]] = None,
-        offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
+        parallelize: bool | None = False,
+        max_memory_per_gpu: int | str | None = None,
+        max_cpu_memory: int | str | None = None,
+        offload_folder: str | os.PathLike | None = "./offload",
         # PEFT, delta weights and quantization options
-        peft: Optional[str] = None,
-        delta: Optional[str] = None,
-        autogptq: Optional[Union[bool, str]] = False,
-        gptqmodel: Optional[bool] = False,
-        gguf_file: Optional[str] = None,
+        peft: str | None = None,
+        delta: str | None = None,
+        autogptq: bool | str | None = False,
+        gptqmodel: bool | None = False,
+        gguf_file: str | None = None,
         # end token for thinking, either the string or int token id.
         # splits to get response after this token (if provided).
-        think_end_token: Union[str, int, None] = None,
+        think_end_token: str | int | None = None,
         **kwargs,
     ) -> None:
         super().__init__()
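
Note: the signature above only changes spelling, not behaviour: `Optional[T]` is `T | None` and `Union[A, B]` is `A | B`. A small runnable check of that equivalence (illustrative, not from the diff):

from typing import Optional, Union

# typing normalises both spellings to the same object, so type checkers and
# get_type_hints() treat the old and new annotations identically.
assert Optional[int] == Union[int, None]
assert Union[int, str, None] == Optional[Union[int, str]]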
@@ -271,9 +270,10 @@ class HFLM(TemplateLM):
             self.batch_size_per_gpu = int(batch_size)

         if isinstance(pretrained, str):
-            if gpus >= 1 or str(self.device) == "mps":
+            if (gpus >= 1 or str(self.device) == "mps") and not (
+                parallelize or autogptq or hasattr(self, "accelerator")
+            ):
                 # TODO: can remove this whole snippet except in the mps case, perhaps?
-                if not (parallelize or autogptq or hasattr(self, "accelerator")):
                 # place model onto device requested manually,
                 # if not using HF Accelerate or device_map
                 # or any other option that preloads model onto device

@@ -327,12 +327,12 @@ class HFLM(TemplateLM):

     def _get_accelerate_args(
         self,
-        parallelize: Optional[bool] = None,
-        device_map: Optional[str] = "auto",
-        max_memory_per_gpu: Optional[Union[int, str]] = None,
-        max_cpu_memory: Optional[Union[int, str]] = None,
-        offload_folder: Optional[str] = "./offload",
-        gpus: Optional[int] = None,
+        parallelize: bool | None = None,
+        device_map: str | None = "auto",
+        max_memory_per_gpu: int | str | None = None,
+        max_cpu_memory: int | str | None = None,
+        offload_folder: str | None = "./offload",
+        gpus: int | None = None,
     ) -> dict:
         """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
         num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))

@@ -480,9 +480,9 @@ class HFLM(TemplateLM):

     def _get_backend(
         self,
-        config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
+        config: transformers.PretrainedConfig | transformers.AutoConfig,
         backend: Literal["default", "causal", "seq2seq"] = "default",
-        trust_remote_code: Optional[bool] = False,
+        trust_remote_code: bool | None = False,
     ) -> None:
         """
         Helper method during initialization.

@@ -497,27 +497,20 @@ class HFLM(TemplateLM):
         if backend != "default":
             # if we've settled on non-default backend, use that manually
-            if backend == "causal":
-                self.backend = backend
-            elif backend == "seq2seq":
+            if backend in ["causal", "seq2seq"]:
                 self.backend = backend
             eval_logger.info(
                 f"Overrode HF model backend type, and using type '{self.backend}'"
             )
         else:
             # determine and use the default HF backend for this model, based on its config + metadata.
-            if (
-                getattr(config, "model_type") in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
-            ):
+            if self.config.model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
                 # first check if model type is listed under seq2seq models, since some
                 # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
                 # these special cases should be treated as seq2seq models.
                 self.backend = "seq2seq"
                 eval_logger.debug(f"Using model type '{self.backend}'")
-            elif (
-                getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
-            ):
+            elif self.config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
                 self.backend = "causal"
                 eval_logger.debug(f"Using model type '{self.backend}'")
             else:

@@ -545,7 +538,7 @@ class HFLM(TemplateLM):
         pretrained: str,
         revision: str = "main",
         trust_remote_code: bool = False,
-        gguf_file: Optional[str] = None,
+        gguf_file: str | None = None,
         subfolder: str = "",
     ) -> None:
         """Return the model config for HuggingFace models"""

@@ -560,24 +553,24 @@ class HFLM(TemplateLM):
     def _create_model(
         self,
         pretrained: str,
-        revision: Optional[str] = "main",
-        dtype: Optional[Union[str, torch.dtype]] = "auto",
-        trust_remote_code: Optional[bool] = False,
+        revision: str | None = "main",
+        dtype: str | torch.dtype | None = "auto",
+        trust_remote_code: bool | None = False,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
         # (accelerate naive PP (device_map) options)
-        parallelize: Optional[bool] = False,
-        gpus: Optional[int] = None,
-        max_memory_per_gpu: Optional[Union[int, str]] = None,
-        max_cpu_memory: Optional[Union[int, str]] = None,
-        offload_folder: Optional[str] = "./offload",
+        parallelize: bool | None = False,
+        gpus: int | None = None,
+        max_memory_per_gpu: int | str | None = None,
+        max_cpu_memory: int | str | None = None,
+        offload_folder: str | None = "./offload",
         # PEFT, delta weights and quantization options
-        peft: Optional[str] = None,
-        delta: Optional[str] = None,
-        autogptq: Optional[Union[bool, str]] = False,
-        gptqmodel: Optional[bool] = False,
-        gguf_file: Optional[str] = None,
-        quantization_config: Optional["AutoQuantizationConfig"] = None,
+        peft: str | None = None,
+        delta: str | None = None,
+        autogptq: bool | str | None = False,
+        gptqmodel: bool | None = False,
+        gguf_file: str | None = None,
+        quantization_config: AutoQuantizationConfig | None = None,
         subfolder: str = "",
         **kwargs,
     ) -> None:
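
Note: `quantization_config` loses its quotes above because, with postponed annotations, a name imported only under `if TYPE_CHECKING:` can be written directly in an annotation; it is resolved by static type checkers and never at runtime. Illustrative sketch with a stand-in class (not from the diff):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only imported for type checkers; never executed at runtime.
    from decimal import Decimal


# "Decimal | None" stays an unevaluated string, so the missing runtime import
# is harmless -- the same reason the quotes around AutoQuantizationConfig can go.
def scale(value: float, factor: Decimal | None = None) -> float:
    return value * float(factor) if factor is not None else value


print(scale(2.0))  # 2.0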
@@ -598,7 +591,7 @@ class HFLM(TemplateLM):
             model_kwargs.update(
                 self._get_accelerate_args(
                     parallelize=parallelize,
-                    device_map=kwargs.get("device_map", None),
+                    device_map=kwargs.get("device_map"),
                     max_memory_per_gpu=max_memory_per_gpu,
                     max_cpu_memory=max_cpu_memory,
                     offload_folder=offload_folder,

@@ -611,12 +604,11 @@ class HFLM(TemplateLM):
             assert transformers.__version__ >= "4.30.0", (
                 "load_in_4bit requires transformers >= 4.30.0"
             )
-            if transformers.__version__ >= "4.30.0":
-                if model_kwargs.get("load_in_4bit", None):
-                    if model_kwargs.get("bnb_4bit_compute_dtype", None):
-                        model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
-                            model_kwargs["bnb_4bit_compute_dtype"]
-                        )
+            if transformers.__version__ >= "4.30.0" and (
+                model_kwargs.get("load_in_4bit")
+                and (compute_dtype := model_kwargs.get("bnb_4bit_compute_dtype"))
+            ):
+                model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(compute_dtype)
             self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                 pretrained,
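
Note: the rewritten 4-bit block folds three nested `if`s into a single condition and uses an assignment expression (`:=`) so the looked-up compute dtype can be reused in the body without a second `get()`. A standalone sketch of the same pattern on a plain dict (illustrative only):

# The walrus operator binds the looked-up value inside the condition,
# so the body can reuse it directly.
config = {"load_in_4bit": True, "bnb_4bit_compute_dtype": "bfloat16"}

if config.get("load_in_4bit") and (compute_dtype := config.get("bnb_4bit_compute_dtype")):
    config["bnb_4bit_compute_dtype"] = compute_dtype.upper()

print(config["bnb_4bit_compute_dtype"])  # BFLOAT16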
@@ -641,7 +633,7 @@ class HFLM(TemplateLM):
                 raise type(exception)(
                     "Tried to load auto_gptq, but auto-gptq is not installed ",
                     "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
-                )
+                ) from exception

             self._model = AutoGPTQForCausalLM.from_quantized(
                 pretrained,

@@ -660,7 +652,7 @@ class HFLM(TemplateLM):
                 raise type(exception)(
                     "Tried to load gptqmodel, but gptqmodel is not installed ",
                     "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`",
-                )
+                ) from exception

             self._model = GPTQModel.from_quantized(
                 pretrained, trust_remote_code=trust_remote_code, **model_kwargs

@@ -672,11 +664,11 @@ class HFLM(TemplateLM):
             )

         if peft:
-            from peft import PeftModel
-            from peft import __version__ as PEFT_VERSION
+            from peft import PeftModel, __version__ as PEFT_VERSION

-            if model_kwargs.get("load_in_4bit", None):
-                if version.parse(PEFT_VERSION) < version.parse("0.4.0"):
-                    raise AssertionError("load_in_4bit requires peft >= 0.4.0")
+            if model_kwargs.get("load_in_4bit") and version.parse(
+                PEFT_VERSION
+            ) < version.parse("0.4.0"):
+                raise AssertionError("load_in_4bit requires peft >= 0.4.0")
             if self._model.config.vocab_size != len(self.tokenizer):
                 # resize model for LoRAs with added tokens

@@ -703,11 +695,13 @@ class HFLM(TemplateLM):
                 try:
                     param.data += _model_delta.state_dict()[name]
                 except KeyError:
-                    raise KeyError(f"Delta model is missing weights for layer: {name}")
+                    raise KeyError(
+                        f"Delta model is missing weights for layer: {name}"
+                    ) from None
                 except Exception as e:
                     raise RuntimeError(
                         f"Failed to add delta weights to layer {name}. Error: {e}"
-                    ) 
+                    ) from e

             del _model_delta
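
Note: the hunks above add explicit exception chaining: `from exception` / `from e` keep the underlying error attached as `__cause__`, while `from None` suppresses the "During handling of the above exception..." chain for a lookup miss that has already been translated. A small sketch of the difference (illustrative, simplified):

def lookup(weights: dict, name: str) -> float:
    try:
        return weights[name]
    except KeyError:
        # "from None" hides the original KeyError; "from err" would keep it
        # attached as __cause__ in the traceback.
        raise KeyError(f"Delta model is missing weights for layer: {name}") from None


try:
    lookup({}, "lm_head.weight")
except KeyError as err:
    print(err.__cause__, err.__suppress_context__)  # None True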
@@ -715,20 +709,17 @@ class HFLM(TemplateLM):

     def _create_tokenizer(
         self,
-        pretrained: Union[str, transformers.PreTrainedModel],
-        tokenizer: Optional[
-            Union[
-                str,
-                transformers.PreTrainedTokenizer,
-                transformers.PreTrainedTokenizerFast,
-            ]
-        ],
-        revision: Optional[str] = "main",
-        trust_remote_code: Optional[bool] = False,
-        use_fast_tokenizer: Optional[bool] = True,
-        gguf_file: Optional[str] = None,
-        add_bos_token: Optional[bool] = False,
-        subfolder: Optional[str] = "",
+        pretrained: str | transformers.PreTrainedModel,
+        tokenizer: str
+        | transformers.PreTrainedTokenizer
+        | transformers.PreTrainedTokenizerFast
+        | None,
+        revision: str | None = "main",
+        trust_remote_code: bool | None = False,
+        use_fast_tokenizer: bool | None = True,
+        gguf_file: str | None = None,
+        add_bos_token: bool | None = False,
+        subfolder: str | None = "",
     ) -> None:
         """
         Helper method during initialization.

@@ -760,8 +751,12 @@ class HFLM(TemplateLM):
                 )
             else:
                 assert isinstance(
-                    tokenizer, transformers.PreTrainedTokenizer
-                ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
+                    tokenizer,
+                    (
+                        transformers.PreTrainedTokenizer,
+                        transformers.PreTrainedTokenizerFast,
+                    ),
+                )
                 self.tokenizer = tokenizer
         else:
             # Get tokenizer based on 'pretrained'
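
Note: `isinstance` accepts a tuple of classes, so the two chained `isinstance(...) or isinstance(...)` calls collapse into one check, as the hunk above does for the tokenizer types. A minimal illustration (not from the diff):

# Equivalent checks: a tuple of types acts as an implicit "or".
value = "fast"
assert isinstance(value, str) or isinstance(value, bytes)
assert isinstance(value, (str, bytes))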
@@ -838,7 +833,7 @@ class HFLM(TemplateLM):
     def tok_encode(
         self, string: str, left_truncate_len=None, add_special_tokens=None
-    ) -> List[int]:
+    ) -> list[int]:
         """ """
         # default for None - empty dict, use predefined tokenizer param
         # used for all models except for CausalLM or predefined value
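
Note: here and in the remaining hunks the capitalised `List`/`Tuple`/`Dict` aliases from `typing` are replaced with the builtin generics `list`/`tuple`/`dict` (PEP 585); under the postponed-annotation import added at the top of the file this is safe even on interpreters that predate builtin subscripting. Illustrative sketch (names are made up):

from __future__ import annotations


# Builtin generics in annotations; no "from typing import List, Tuple, Dict" needed.
def split_pairs(pairs: list[tuple[str, int]]) -> dict[str, int]:
    return {key: value for key, value in pairs}


print(split_pairs([("a", 1), ("b", 2)]))  # {'a': 1, 'b': 2}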
@@ -864,11 +859,11 @@ class HFLM(TemplateLM):
     def tok_batch_encode(
         self,
-        strings: List[str],
+        strings: list[str],
         padding_side: str = "left",
         left_truncate_len: int = None,
         truncation: bool = False,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
         old_padding_side = self.tokenizer.padding_side
         self.tokenizer.padding_side = padding_side

@@ -917,15 +912,17 @@ class HFLM(TemplateLM):
             A torch tensor of shape [batch, sequence, vocab] with the
             logits returned from the model's decoder
         """
-        with torch.no_grad():
-            with torch.autocast(
+        with (
+            torch.no_grad(),
+            torch.autocast(
                 device_type=self.device.type,
                 dtype=self.mixed_precision_dtype,
                 enabled=self.mixed_precision_dtype is not None,
-            ):
+            ),
+        ):
             if attn_mask is not None or labels is not None:
                 assert attn_mask is not None and labels is not None
-                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
+                assert transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS
                 return self.model(
                     input_ids=inps, attention_mask=attn_mask, labels=labels
                 ).logits
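
Note: the two nested `with` statements are merged into a single parenthesised `with (A(), B()):` block; the parenthesised spelling is officially supported from Python 3.10 (the unparenthesised `with A(), B():` form works on older interpreters as well). A sketch of the shape with contextlib stand-ins (illustrative only):

from contextlib import contextmanager


@contextmanager
def managed(tag: str):
    print(f"enter {tag}")
    yield
    print(f"exit {tag}")


# One statement, two context managers -- same shape as the merged
# torch.no_grad()/torch.autocast() block above (parenthesised form: 3.10+).
with (
    managed("no_grad"),
    managed("autocast"),
):
    print("forward pass")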
@@ -942,7 +939,7 @@ class HFLM(TemplateLM):
             # remove temperature, as do_sample=False takes care of this
             # and we don't want a warning from HF
             generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
-            do_sample = generation_kwargs.get("do_sample", None)
+            do_sample = generation_kwargs.get("do_sample")

             # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
             if generation_kwargs.get("temperature") == 0.0 and do_sample is None:

@@ -989,8 +986,8 @@ class HFLM(TemplateLM):
         return logits

     def loglikelihood_rolling(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[float]:
         adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context

@@ -1009,7 +1006,7 @@ class HFLM(TemplateLM):
                 disable=(disable_tqdm or (self.rank != 0)),
             )
         ):
-            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
+            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                 map(
                     utils.make_disjoint_window,
                     utils.get_rolling_token_windows(

@@ -1093,14 +1090,14 @@ class HFLM(TemplateLM):

     def _loglikelihood_tokens(
         self,
-        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        requests: list[tuple[tuple[str, str], list[int], list[int]]],
         disable_tqdm: bool = False,
         override_bs: int = None,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []

-        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
+        def _collate(req: tuple[tuple[str, str], list[int], list[int]]):
             """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning

@@ -1112,7 +1109,7 @@ class HFLM(TemplateLM):
             toks = req[1] + req[2]
             return -len(toks), tuple(toks)

-        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
+        def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]):
             """Defines the key to group and lookup one-token continuations"""
             # Use with group_by="contexts" (optional)"
             # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.

@@ -1286,7 +1283,7 @@ class HFLM(TemplateLM):
             # original args. Otherwise, expands the logits batch dimension and yields each
             # batch along with matching continuation tokens and prompt strings.
             # logits -> [1, seq, vocab]
-            for request_str, cont_toks, logits in re_ord.get_cache(
+            for request_str, cont_toks, logits in re_ord.get_cache(  # noqa
                 req_str=request_str,
                 cxt_toks=ctx_tokens,
                 cont_toks=cont_toks,

@@ -1327,11 +1324,11 @@ class HFLM(TemplateLM):
         return re_ord.get_original(res)

     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         res = []

-        def _collate(req: Tuple[str, dict]):
+        def _collate(req: tuple[str, dict]):
             """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning

@@ -1394,7 +1391,7 @@ class HFLM(TemplateLM):
                 raise ValueError(
                     f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            if "max_gen_toks" in kwargs.keys():
+            if "max_gen_toks" in kwargs:
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
                 max_gen_toks = self.max_gen_toks

@@ -1472,7 +1469,7 @@ class HFLM(TemplateLM):
         return res

     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
     ) -> str:
         """
         Method to apply a chat template to a list of chat history between user and model.
         ...
lm_eval/models/vllm_causallms.py (view file @ 6d63c2ce)
+from __future__ import annotations
+
 import copy
 import gc
 import inspect

@@ -8,7 +10,7 @@ from importlib.util import find_spec
 from multiprocessing import Process, Queue
 from queue import Empty
 from time import sleep
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal

 import jinja2
 from more_itertools import distribute

@@ -51,10 +53,10 @@ eval_logger = logging.getLogger(__name__)

 def _vllm_mp_worker(
     model_args: dict,
-    sampling_params: "SamplingParams",
+    sampling_params: SamplingParams,
     requests: list[list[int]],
-    lora_request: "LoRARequest",
-    result_queue: "Queue",
+    lora_request: LoRARequest,
+    result_queue: Queue,
     dp_size: int,
     local_dp_rank: int,
     dp_master_port: int,

@@ -114,30 +116,30 @@ class VLLM(TemplateLM):
         self,
         pretrained: str,
         dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
-        revision: Optional[str] = None,
-        trust_remote_code: Optional[bool] = False,
-        tokenizer: Optional[str] = None,
+        revision: str | None = None,
+        trust_remote_code: bool | None = False,
+        tokenizer: str | None = None,
         tokenizer_mode: Literal["auto", "slow"] = "auto",
-        tokenizer_revision: Optional[str] = None,
-        add_bos_token: Optional[bool] = False,
-        prefix_token_id: Optional[int] = None,
+        tokenizer_revision: str | None = None,
+        add_bos_token: bool | None = False,
+        prefix_token_id: int | None = None,
         tensor_parallel_size: int = 1,
-        quantization: Optional[str] = None,
+        quantization: str | None = None,
         max_gen_toks: int = 256,
         swap_space: int = 4,
-        batch_size: Union[str, int] = 1,
-        max_batch_size=None,
-        max_length: int = None,
-        max_model_len: int = None,
+        batch_size: str | int = 1,
+        max_batch_size: int | None = None,
+        max_length: int | None = None,
+        max_model_len: int | None = None,
         seed: int = 1234,
         gpu_memory_utilization: float = 0.9,
         device: str = "cuda",
         data_parallel_size: int = 1,
-        lora_local_path: str = None,
+        lora_local_path: str | None = None,
         # VLLM: enable thinking tags in the prompt.
         enable_thinking: bool = True,
         # End marker for thinking tags - splits to get response after this token (if provided).
-        think_end_token: Optional[str] = None,
+        think_end_token: str | None = None,
         max_lora_rank: int = 16,
         **kwargs,
     ):

@@ -173,7 +175,7 @@ class VLLM(TemplateLM):
             "quantization": quantization,
             "seed": int(seed),
             "device": str(device),
-            "enable_lora": True if lora_local_path else False,
+            "enable_lora": bool(lora_local_path),
             "max_lora_rank": int(max_lora_rank),
         }
         self.model_args.update(kwargs)
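
Note: `True if lora_local_path else False` is simply the truthiness of the value, so it becomes `bool(lora_local_path)`. Tiny illustration (the path string is a placeholder):

lora_local_path = None
assert bool(lora_local_path) == (True if lora_local_path else False)

lora_local_path = "/tmp/adapter"
assert bool(lora_local_path) == (True if lora_local_path else False)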
@@ -304,7 +306,7 @@ class VLLM(TemplateLM):
         return self._max_gen_toks

     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
     ) -> str:
         """
         Method to apply a chat template to a list of chat history between user and model.

@@ -339,14 +341,14 @@ class VLLM(TemplateLM):
     def tok_encode(
         self,
-        string: Union[str, List[str]],
+        string: str | list[str],
         left_truncate_len: int = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
-    ) -> Union[List[int], List[List[int]]]:
+    ) -> list[int] | list[list[int]]:
         if not add_special_tokens:
             add_special_tokens = False or self.add_bos_token
-        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+        encoding: list[list[int]] | list[int] = self.tokenizer(
             string,
             add_special_tokens=add_special_tokens,
             truncation=truncation,

@@ -364,10 +366,10 @@ class VLLM(TemplateLM):
     def _model_generate(
         self,
-        requests: List[List[int]] = None,
+        requests: list[list[int]] = None,
         generate: bool = False,
         max_tokens: int = None,
-        stop: Optional[List[str]] = None,
+        stop: list[str] | None = None,
         **kwargs,
     ):
         if generate:

@@ -385,7 +387,7 @@ class VLLM(TemplateLM):
             def run_inference_one_model(
                 model_args: dict,
                 sampling_params: SamplingParams,
-                requests: List[List[int]],
+                requests: list[list[int]],
                 lora_request: LoRARequest,
             ):
                 llm = LLM(**model_args)

@@ -454,7 +456,7 @@ class VLLM(TemplateLM):
                     if dead_procs:
                         raise RuntimeError(
                             f"Worker processes {dead_procs} died unexpectedly"
-                        )
+                        ) from None
                     continue

             results = [rank_res[i] for i in range(len(procs))]

@@ -481,14 +483,14 @@ class VLLM(TemplateLM):
             outputs = self.model.generate(
                 prompt_token_ids=requests,
                 sampling_params=sampling_params,
-                use_tqdm=True if self.batch_size == "auto" else False,
+                use_tqdm=self.batch_size == "auto",
                 lora_request=self.lora_request,
             )
             return outputs

     def loglikelihood_rolling(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[float]:
         adaptive_batch_size = None
         if self.batch_size == "auto":
             adaptive_batch_size = len(requests)

@@ -503,7 +505,7 @@ class VLLM(TemplateLM):
                 disable=(disable_tqdm or (self.rank != 0)),
             )
         ):
-            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
+            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                 map(
                     make_disjoint_window,
                     get_rolling_token_windows(

@@ -556,13 +558,13 @@ class VLLM(TemplateLM):
         return loglikelihoods

     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         res = []

         # batch tokenize contexts
         context, all_gen_kwargs = zip(*(req.args for req in requests))
-        context_encoding: List[List[int]] = self.tok_encode(
+        context_encoding: list[list[int]] = self.tok_encode(
             context, add_special_tokens=self.add_bos_token
         )
         requests = [

@@ -608,7 +610,7 @@ class VLLM(TemplateLM):
                 raise ValueError(
                     f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            if "max_gen_toks" in kwargs.keys():
+            if "max_gen_toks" in kwargs:
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
                 max_gen_toks = self.max_gen_toks

@@ -634,7 +636,7 @@ class VLLM(TemplateLM):
             )

             # cache generations
-            for output, context in zip(cont, context):
+            for output, context_ in zip(cont, context):
                 generated_text: str = output.outputs[0].text
                 # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                 generated_text = postprocess_generated_text(

@@ -642,7 +644,7 @@ class VLLM(TemplateLM):
                 )
                 res.append(generated_text)
                 self.cache_hook.add_partial(
-                    "generate_until", (context, gen_kwargs), generated_text
+                    "generate_until", (context_, gen_kwargs), generated_text
                 )
                 pbar.update(1)
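
Note: the loop variable is renamed to `context_` (and the cached key a few lines down is updated to match) because `for output, context in zip(cont, context)` rebinds the name `context` while iterating over it, shadowing the batch of contexts. One reason such renames are preferred, sketched in simplified form (illustrative, not from the diff):

contexts = ["a", "b", "c"]
outputs = [1, 2, 3]

# Shadowing hazard: "for output, contexts in zip(outputs, contexts): ..."
# would leave "contexts" bound to the last element after the loop, so any
# later use of the original list would silently break.

# Renamed loop variable keeps the original name intact, as in the diff.
for output, context_ in zip(outputs, contexts):
    pass

assert contexts == ["a", "b", "c"]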
@@ -652,9 +654,9 @@ class VLLM(TemplateLM):

     def _loglikelihood_tokens(
         self,
-        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        requests: list[tuple[tuple[str, str], list[int], list[int]]],
         disable_tqdm: bool = False,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         res = []

         def _collate(x):

@@ -675,7 +677,7 @@ class VLLM(TemplateLM):
         for chunk in chunks:
             inputs = []
             ctxlens = []
-            for cache_key, context_enc, continuation_enc in chunk:
+            for _cache_key, context_enc, continuation_enc in chunk:
                 if (
                     full_length := len(context_enc + continuation_enc)
                 ) > self.max_length:

@@ -713,7 +715,7 @@ class VLLM(TemplateLM):
         return re_ord.get_original(res)

     @staticmethod
-    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
+    def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
         """Process logprobs and tokens.

         :param tokens: list
         ...