gaoqiong / lm-evaluation-harness

Commit e9953abb: Override greedy_until
Authored Nov 09, 2023 by mgoin
Parent: 2c8e66d7

Showing 1 changed file with 59 additions and 14 deletions

lm_eval/models/deepsparse.py (+59, -14)
@@ -67,19 +67,65 @@ class DeepSparseLM(BaseLM):
         logits_numpy = numpy.stack([generation.score for generation in out.generations])
         return torch.from_numpy(logits_numpy)
 
-    def _model_generate(self, context, max_length, eos_token_id):
-        # Encode the prompt tokens to strings
-        prompt = self.tokenizer.batch_decode(context.numpy())
-
-        # Run generation
-        out = self.model(
-            prompt=prompt, max_new_tokens=max_length, force_max_tokens=True
-        )
-
-        # Return tokens for prompt + generated text
-        return numpy.array(
-            [self.tokenizer(prompt[0] + out.generations[0].text)["input_ids"]]
-        )
+    def greedy_until(
+        self, requests: List[Tuple[str, Union[List[str], str]]]
+    ) -> List[str]:
+        def _collate(x):
+            tokens = self.tok_encode(x[0])
+            return len(tokens), x[0]
+
+        results = []
+        reorder = utils.Reorderer(requests, _collate)
+
+        for chunk in utils.chunks(
+            tqdm(reorder.get_reordered(), disable=False),
+            self.batch_size,
+        ):
+            context = [c[0] for c in chunk]
+            request_args = chunk[0][1]
+            stop = request_args.get("until", None)
+            stop_sequences = stop if isinstance(stop, list) else [stop]
+            max_generation_length = request_args.get("max_length", None)
+
+            assert (
+                isinstance(max_generation_length, int) or max_generation_length is None
+            )
+            assert isinstance(stop_sequences, list) or stop_sequences is None
+
+            # TODO: Find a better way to handle stop sequences for 0-shot.
+            if stop_sequences is None:
+                until = [self.eot_token]
+            else:
+                until = stop_sequences + [self.eot_token]
+
+            if max_generation_length is None:
+                max_tokens = self.max_gen_toks
+            else:
+                max_tokens = max_generation_length
+
+            responses = self.model(
+                sequences=context,
+                max_new_tokens=max_tokens,
+                stop=until,
+                do_sample=False,
+            )
+
+            responses = responses if type(responses) is list else [responses]
+
+            for response in responses:
+                response = response.generations[0].text
+
+                # Ensure the generated responses do not contain the stop sequences.
+                for term in until:
+                    response = response.split(term)[0]
+
+                # partial caching
+                self.cache_hook.add_partial("greedy_until", (context, until), response)
+
+                results.append(response)
+
+        return reorder.get_original(results)
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        # Isn't used because we override greedy_until
+        raise NotImplementedError()
 
     @property
     def eot_token(self) -> str:
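As a standalone sketch (not part of the commit), the snippet below isolates the stop-sequence handling that the new greedy_until applies to each generated string: the text is cut at the first occurrence of any term in until, so the stop term and everything after it are dropped. The function name and the example strings are hypothetical.

```python
from typing import List


def trim_at_stop_sequences(text: str, until: List[str]) -> str:
    # Splitting on a stop term and keeping only the first piece removes the
    # term itself and anything generated after it.
    for term in until:
        text = text.split(term)[0]
    return text


# Hypothetical generation that ran past the desired stopping point.
generated = "Paris.\nQuestion: What is the capital of Spain?"
print(trim_at_stop_sequences(generated, until=["\nQuestion:", "<|endoftext|>"]))
# Prints: Paris.
```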
@@ -106,8 +152,7 @@ class DeepSparseLM(BaseLM):
         pass
 
     def tok_encode(self, string: str):
-        return self.tokenizer.encode(string)
+        return self.tokenizer.encode(string, add_special_tokens=False)
 
     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens)
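The second hunk makes tok_encode pass add_special_tokens=False. As a hedged illustration of why that matters (not taken from this repository; the tokenizer name is only an example), a Hugging Face tokenizer's default encode() can wrap the text in special tokens, which would otherwise be injected into the middle of concatenated context/continuation strings during evaluation.

```python
from transformers import AutoTokenizer

# Any tokenizer that inserts special tokens by default shows the effect;
# bert-base-uncased is used purely as an example.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

with_specials = tokenizer.encode("hello world")
without_specials = tokenizer.encode("hello world", add_special_tokens=False)

print(len(with_specials), len(without_specials))
# with_specials includes the [CLS] and [SEP] wrappers; without_specials
# contains only the tokens for "hello world" itself, matching what the
# updated tok_encode returns.
```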