gaoqiong / lm-evaluation-harness

Commit b0f67f2c, authored Aug 07, 2023 by Hailey Schoelkopf, committed by GitHub Aug 07, 2023

Merge pull request #736 from baberabb/big-refactor_opai

[Refactor] fixed openai

Parents: 6c753760, ca386392

Showing 6 changed files with 158 additions and 76 deletions (+158 −76)
docs/model_guide.md                    +7  −3
lm_eval/api/model.py                   +20 −19
lm_eval/evaluator.py                   +3  −1
lm_eval/models/anthropic_llms.py       +31 −16
lm_eval/models/openai_completions.py   +96 −36
setup.py                               +1  −1
docs/model_guide.md

````diff
@@ -36,15 +36,19 @@ The LM class enforces a common interface via which we can extract responses from
 ```python
 class MyCustomLM(LM):
     #...
-    def loglikelihood(self, requests):
+    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
         #...


-    def loglikelihood_rolling(self, requests):
+    def loglikelihood_rolling(self, requests: list[Instance]) -> list[tuple[float, bool]]:
         #...


-    def greedy_until(self, requests):
+    def greedy_until(self, requests: list[Instance]) -> list[str]:
         #...
     #...
 ```
+Where `Instance` is a dataclass defined in [`lm_eval.api.instance`](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/lm_eval/api/instance.py) with property `args` which returns a tuple of (context, continuation).
+
 We support ...
````
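For orientation, a minimal sketch of a subclass conforming to this updated interface might look as follows. This is not part of the commit: the constant return values are placeholders standing in for real model calls.

```python
from typing import List, Tuple

from lm_eval.api.instance import Instance
from lm_eval.api.model import LM


class MyCustomLM(LM):
    # Illustrative stub: returns values of the right shape instead of real scores.

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        # Each Instance carries (context, continuation) in its `args` property.
        return [(0.0, False) for _ in requests]

    def loglikelihood_rolling(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        return [(0.0, False) for _ in requests]

    def greedy_until(self, requests: List[Instance]) -> List[str]:
        # For generation requests, `args` is (context, until) instead.
        return ["" for _ in requests]
```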
lm_eval/api/model.py

```diff
 import abc
 import os
-from typing import Union
+from typing import Union, List, Tuple

 from sqlitedict import SqliteDict
 import json
 import hashlib
@@ -25,31 +25,32 @@ class LM(abc.ABC):
         self.cache_hook = CacheHook(None)

     @abc.abstractmethod
-    def loglikelihood(self, requests):
+    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
         """Compute log-likelihood of generating a continuation from a context.

         Downstream tasks should attempt to use loglikelihood instead of other
         LM calls whenever possible.

-        :param requests: list
-            A list of pairs (context, continuation)
-            context: str
+        :param requests: list[Instance]
+            A list of Instance objects, with property `args` which returns a tuple
+            (context, continuation).
+            `context: str`
                 Context string. Implementations of LM must be able to handle an
                 empty context string.
-            continuation: str
+            `continuation: str`
                 The continuation over which log likelihood will be calculated. If
                 there is a word boundary, the space should be in the continuation.
                 For example, context="hello" continuation=" world" is correct.
-        :return: list
+        :return: list[tuple[float, bool]]
             A list of pairs (logprob, isgreedy)
-            logprob: float
-                The log probability of `continuation`
-            isgreedy:
-                Whether `continuation` would be generated by greedy sampling from `context`
+            `logprob: float`
+                The log probability of `continuation`.
+            `isgreedy`:
+                Whether `continuation` would be generated by greedy sampling from `context`.
         """
         pass

     @abc.abstractmethod
-    def loglikelihood_rolling(self, requests):
+    def loglikelihood_rolling(self, requests) -> List[Tuple[float, bool]]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
         - We will use the full max context length of the model.
         - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
@@ -77,11 +78,11 @@ class LM(abc.ABC):
             1. Each token is predicted exactly once
             2. For the last pair, we provide the full context, but only score the last two tokens

-        :param requests: list
-            A list of strings
+        :param requests: list[Instance]
+            A list of Instance objects with property `args` which returns a tuple (context, continuation).
             string: str
                 String for which we are computing per-token loglikelihood
-        :return: list
+        :return: list[tuple[float, bool]]
             A list of pairs (logprob, isgreedy)
             logprob: float
                 The log probability of `continuation`
@@ -92,17 +93,17 @@ class LM(abc.ABC):
     # TODO: Add an optional max length
     @abc.abstractmethod
-    def greedy_until(self, requests):
+    def greedy_until(self, requests) -> List[str]:
         """Generate greedily until a stopping sequence

-        :param requests: list
-            A list of pairs (context, until)
+        :param requests: list[Instance]
+            A list of Instance objects with property `args` which returns a tuple
+            (context, until).
             context: str
                 Context string
             until: [str]
                 The string sequences to generate until. These string sequences
                 may each span across multiple tokens, or may be part of one token.
-        :return: list
+        :return: list[str]
             A list of strings continuation
             continuation: str
                 The generated continuation.
```
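The word-boundary convention in the `loglikelihood` docstring (the space belongs to the continuation, not the context) is easy to get wrong when building requests by hand. A small illustrative helper, not part of the commit:

```python
def split_at_word_boundary(sentence: str, n_context_words: int) -> tuple:
    """Split a sentence so any leading space stays with the continuation."""
    words = sentence.split(" ")
    context = " ".join(words[:n_context_words])
    continuation = sentence[len(context):]
    return context, continuation


# Matches the docstring's example: context="hello", continuation=" world".
assert split_at_word_boundary("hello world", 1) == ("hello", " world")
```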
lm_eval/evaluator.py

```diff
@@ -85,7 +85,9 @@ def simple_evaluate(
             1234
         )  # TODO: this may affect training runs that are run with evaluation mid-run.

-    assert tasks != [], "No tasks specified"
+    assert (
+        tasks != []
+    ), "No tasks specified, or no tasks found. Please verify the task names."

     if isinstance(model, str):
         if model_args is None:
```
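Since the assertion runs before any model is resolved, a mis-specified task list now fails fast with the clearer message. A sketch of what a caller would see (the model name below is illustrative and assumes the usual `simple_evaluate` signature):

```python
from lm_eval.evaluator import simple_evaluate

try:
    # An empty (or entirely unresolved) task list fails immediately.
    simple_evaluate(model="openai", tasks=[])
except AssertionError as err:
    print(err)  # -> No tasks specified, or no tasks found. Please verify the task names.
```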
lm_eval/models/anthropic_llms.py

```diff
-import os
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from tqdm import tqdm
 import time
 from lm_eval.logger import eval_logger
-from typing import List, Literal, Any
+from typing import List, Any, Tuple


 def anthropic_completion(
@@ -15,10 +14,25 @@ def anthropic_completion(
     temperature: float,
     stop: List[str],
     **kwargs: Any,
-):
-    """Query Anthropic API for completion.
+) -> str:
+    """Wrapper function around the Anthropic completion API client with exponential back-off
+    in case of RateLimitError.

-    Retry with back-off until they respond
+    params:
+        client: anthropic.Anthropic
+            Anthropic API client
+        model: str
+            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
+        prompt: str
+            Prompt to feed to the model
+        max_tokens_to_sample: int
+            Maximum number of tokens to sample from the model
+        temperature: float
+            Sampling temperature
+        stop: List[str]
+            List of stop sequences
+        kwargs: Any
+            Additional model_args to pass to the API client
     """
     try:
@@ -29,7 +43,7 @@ def anthropic_completion(
 please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
         )

-    backoff_time = 3
+    backoff_time: float = 3
     while True:
         try:
             response = client.completions.create(
@@ -94,15 +108,15 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`"
     @property
     def eot_token_id(self):
-        # Not sure but anthropic.AI_PROMPT -> [203, 203, 50803, 30]
+        # Not sure but anthropic.HUMAN_PROMPT ?
         raise NotImplementedError("No idea about anthropic tokenization.")

     @property
-    def max_length(self):
+    def max_length(self) -> int:
         return 2048

     @property
-    def max_gen_toks(self):
+    def max_gen_toks(self) -> int:
         return self.max_tokens_to_sample

     @property
@@ -124,14 +138,15 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`"
     def _loglikelihood_tokens(self, requests, disable_tqdm=False):
         raise NotImplementedError("No support for logits.")

-    def greedy_until(self, requests):
+    def greedy_until(self, requests) -> List[str]:
         if not requests:
             return []

-        requests = [req.args for req in requests]
+        _requests: List[Tuple[str, dict]] = [req.args for req in requests]

         res = []
-        for request in tqdm(requests):
+        for request in tqdm(_requests):
             try:
                 inp = request[0]
                 request_args = request[1]
@@ -145,16 +160,16 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`"
                     prompt=inp,
                     max_tokens_to_sample=max_gen_toks,
-                    temperature=temperature,
-                    stop=until,
+                    temperature=temperature,  # TODO: implement non-greedy sampling for Anthropic
+                    stop=until,  # type: ignore
                     **self.kwargs,
                 )
                 res.append(response)

                 self.cache_hook.add_partial("greedy_until", request, response)
-            except anthropic.APIConnectionError as e:  # noqa: F821
+            except anthropic.APIConnectionError as e:  # type: ignore # noqa: F821
                 eval_logger.critical(f"Server unreachable: {e.__cause__}")
                 break
-            except anthropic.APIStatusError as e:  # noqa: F821
+            except anthropic.APIStatusError as e:  # type: ignore # noqa: F821
                 eval_logger.critical(f"API error {e.status_code}: {e.message}")
                 break
```
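The retry loop wrapped by `anthropic_completion` is a standard grow-on-rate-limit back-off. A generic sketch of the pattern; the exception class and constants here are stand-ins, not values taken from this diff:

```python
import time


class RateLimitError(Exception):
    """Stand-in for anthropic.RateLimitError in this sketch."""


def with_backoff(call, initial_delay: float = 3.0, factor: float = 1.5):
    # Keep calling until the request succeeds; sleep and grow the delay
    # each time the API reports a rate limit.
    backoff_time = initial_delay
    while True:
        try:
            return call()
        except RateLimitError:
            time.sleep(backoff_time)
            backoff_time *= factor
```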
lm_eval/models/openai_completions.py

```diff
 import os
 import time
-import transformers
+from typing import List, Tuple
+
+import numpy as np
 from tqdm import tqdm

 from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model


-def get_result(response, ctxlen):
+def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
     """Process results from OpenAI API response.

     :param response: dict
```
```diff
@@ -43,7 +40,13 @@ def oa_completion(**kwargs):
     Retry with back-off until they respond
     """
-    import openai
+    try:
+        import openai, tiktoken  # noqa: E401
+    except ModuleNotFoundError:
+        raise Exception(
+            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
+please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
+        )

     backoff_time = 3
     while True:
@@ -61,7 +64,12 @@ def oa_completion(**kwargs):
 class OpenaiCompletionsLM(LM):
     REQ_CHUNK_SIZE = 20

-    def __init__(self, engine, truncate=False):
+    def __init__(
+        self,
+        engine: str = "text-davinci-003",
+        truncate: bool = False,
+        batch_size: int = 1,
+    ):
         """

         :param engine: str
```
```diff
@@ -70,28 +78,25 @@ class OpenaiCompletionsLM(LM):
             Truncate input if too long (if False and input is too long, throw error)
         """
         super().__init__()

-        import openai
+        try:
+            import openai, tiktoken  # noqa: E401
+        except ModuleNotFoundError:
+            raise Exception(
+                "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
+please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
+            )

         self.engine = engine
-        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
-        self.vocab_size = self.tokenizer.vocab_size
-        # to make the annoying "Using pad_token, but it is not set yet." error go away
-        self.tokenizer.pad_token = "<|endoftext|>"
-        assert self.tokenizer.encode("hello\n\nhello") == [31373, 198, 198, 31373]
+        self.tokenizer = tiktoken.encoding_for_model(self.engine)
+        self.vocab_size = self.tokenizer.n_vocab
         self.truncate = truncate
-        self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(
-            ["<|endoftext|>"]
-        )[0]
+        self.end_of_text_token_id = self.tokenizer.eot_token

         # Read from environment variable OPENAI_API_SECRET_KEY
         openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

     @property
     def eot_token_id(self):
-        return self.tokenizer.eos_token_id
+        return self.end_of_text_token_id

     @property
     def max_length(self):
```
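Swapping `transformers.GPT2TokenizerFast` for `tiktoken.encoding_for_model` keeps the tokenizer matched to the requested engine and drops the Hugging Face dependency. A quick sanity check of the tiktoken attributes the new code relies on (requires the `tiktoken` package):

```python
import tiktoken

enc = tiktoken.encoding_for_model("text-davinci-003")
tokens = enc.encode("hello\n\nhello")

assert enc.decode(tokens) == "hello\n\nhello"  # encode/decode round-trips
print(enc.n_vocab)    # vocabulary size, tiktoken's analogue of vocab_size
print(enc.eot_token)  # id of the <|endoftext|> token
```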
```diff
@@ -112,19 +117,49 @@ class OpenaiCompletionsLM(LM):
         # Isn't used because we override _loglikelihood_tokens
         raise NotImplementedError()

-    def tok_encode(self, string: str):
-        return self.tokenizer.encode(string, add_special_tokens=False)
+    def tok_encode(self, string: str) -> List[int]:
+        return self.tokenizer.encode(string)

-    def tok_decode(self, tokens):
+    def tok_decode(self, tokens: List[int]) -> str:
         return self.tokenizer.decode(tokens)

-    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
+    def _encode_pair(
+        self, context: str, continuation: str
+    ) -> Tuple[List[int], List[int]]:
+        n_spaces = len(context) - len(context.rstrip())
+        if n_spaces > 0:
+            continuation = context[-n_spaces:] + continuation
+            context = context[:-n_spaces]
+        whole_enc = self.tok_encode(context + continuation)
+        context_enc = self.tok_encode(context)
+        context_enc_len = len(context_enc)
+        continuation_enc = whole_enc[context_enc_len:]
+        return context_enc, continuation_enc
+
+    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
+        new_reqs = []
+        for context, continuation in [req.args for req in requests]:
+            if context == "":
+                # end of text as context
+                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
+                    continuation
+                )
+            else:
+                context_enc, continuation_enc = self._encode_pair(context, continuation)
+
+            new_reqs.append(((context, continuation), context_enc, continuation_enc))
+
+        return self._loglikelihood_tokens(new_reqs)
+
+    def _loglikelihood_tokens(
+        self, requests, disable_tqdm=False
+    ) -> List[Tuple[float, bool]]:
         res = []

         def _collate(x):
             # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
             # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
-            # we care about and so we need some kind of backup for when it isn't
+            # we care about, and so we need some kind of backup for when it isn't
             toks = x[1] + x[2]
             return -len(toks), tuple(toks)
```
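The reason `_encode_pair` re-encodes the whole string instead of encoding the halves separately is that BPE merges can cross the context/continuation boundary, most visibly around spaces. A small demonstration of the failure mode and the fix (assumes `tiktoken`; exact token ids depend on the encoding):

```python
import tiktoken

enc = tiktoken.encoding_for_model("text-davinci-003")

# Naive split: a context with a trailing space tokenizes the space alone,
# while encoding the joined string merges it into " world".
context, continuation = "hello ", "world"
naive = enc.encode(context) + enc.encode(continuation)
whole = enc.encode(context + continuation)
print(naive == whole)  # typically False for BPE encodings

# _encode_pair's approach: move trailing spaces onto the continuation,
# then slice the whole-string encoding at the context boundary.
n_spaces = len(context) - len(context.rstrip())
context, continuation = context[:-n_spaces], context[-n_spaces:] + continuation
context_enc = enc.encode(context)
continuation_enc = whole[len(context_enc):]
assert context_enc + continuation_enc == whole  # consistent by construction
```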
```diff
@@ -166,13 +201,13 @@ class OpenaiCompletionsLM(LM):
                 # partial caching
                 if cache_key is not None:
                     self.cache_hook.add_partial("loglikelihood", cache_key, answer)

         return re_ord.get_original(res)

-    def greedy_until(self, requests):
+    def greedy_until(self, requests) -> List[str]:
         if not requests:
             return []
         res = []
+        requests = [req.args for req in requests]

         def _collate(x):
             toks = self.tok_encode(x[0])
@@ -203,12 +238,7 @@ class OpenaiCompletionsLM(LM):
                 inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                 inps.append(inp)

-            try:
-                until = request_args["until"][0]
-                # TODO: does this handle a list of stop seqs correctly?
-            except KeyError:
-                until = "<|endoftext|>"
+            until = request_args.get("until", ["<|endoftext|>"])

             response = oa_completion(
                 engine=self.engine,
@@ -222,7 +252,7 @@ class OpenaiCompletionsLM(LM):
             for resp, (context, args_) in zip(response.choices, chunk):
                 s = resp["text"]

-                until_ = args_.get(["until"], [])
+                until_ = args_.get("until", ["<|endoftext|>"])

                 for term in until_:
                     if len(term) > 0:
@@ -234,7 +264,6 @@ class OpenaiCompletionsLM(LM):
                 )
                 res.append(s)

         return re_ord.get_original(res)

     def _model_call(self, inps):
```
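The `until_` change fixes a genuine bug: the old call passed the list `["until"]` as the dictionary key, and `dict.get` hashes its key, so this raised `TypeError` whenever it ran. In isolation:

```python
request_args = {"until": ["\n\n"]}

try:
    request_args.get(["until"], [])  # old form: the list itself is the key
except TypeError as err:
    print(err)  # unhashable type: 'list'

# Fixed form: look up the "until" key, defaulting to the end-of-text marker.
print(request_args.get("until", ["<|endoftext|>"]))  # ['\n\n']
```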
```diff
@@ -244,3 +273,34 @@ class OpenaiCompletionsLM(LM):
     def _model_generate(self, context, max_length, eos_token_id):
         # Isn't used because we override greedy_until
         raise NotImplementedError()
+
+    def loglikelihood_rolling(self, requests) -> List[float]:
+        loglikelihoods = []
+
+        for (string,) in tqdm([req.args for req in requests]):
+            rolling_token_windows = list(
+                map(
+                    utils.make_disjoint_window,
+                    utils.get_rolling_token_windows(
+                        token_list=self.tok_encode(string),
+                        prefix_token=self.eot_token_id,
+                        max_seq_len=self.max_length,
+                        context_len=1,
+                    ),
+                )
+            )
+
+            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
+            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
+
+            string_nll = self._loglikelihood_tokens(
+                rolling_token_windows,
+                disable_tqdm=True,
+            )
+
+            # discard is_greedy
+            string_nll = [x[0] for x in string_nll]
+
+            string_nll = sum(string_nll)
+            loglikelihoods.append(string_nll)
+
+        return loglikelihoods
```
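The new `loglikelihood_rolling` scores an entire string by summing log-probabilities over disjoint rolling windows, so every token is predicted exactly once even when the string exceeds the model's context length. The real window logic lives in `lm_eval.utils`; the toy version below only illustrates the shape of the idea:

```python
def rolling_windows(tokens, prefix_token, max_seq_len):
    # Each window predicts a disjoint span of tokens; its input is the
    # preceding tokens (capped at max_seq_len), seeded with a prefix token.
    windows = []
    start = 0
    while start < len(tokens):
        end = min(start + max_seq_len, len(tokens))
        inp = ([prefix_token] + tokens[:end - 1])[-max_seq_len:]
        windows.append((inp, tokens[start:end]))
        start = end
    return windows


print(rolling_windows(list(range(5)), prefix_token=-1, max_seq_len=3))
# [([-1, 0, 1], [0, 1, 2]), ([1, 2, 3], [3, 4])]
```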
setup.py

```diff
@@ -36,7 +36,6 @@ setuptools.setup(
         "evaluate>=0.4.0",
         "jsonlines",
         "numexpr",
-        "openai>=0.6.4",
         "omegaconf>=2.2",
         "peft>=0.2.0",
         "pybind11>=2.6.2",
@@ -67,5 +66,6 @@ setuptools.setup(
         ],
         "gptq": ["auto-gptq[triton] @ git+https://github.com/PanQiWei/AutoGPTQ"],
         "anthropic": ["anthropic"],
+        "openai": ["openai", "tiktoken"],
     },
 )
```