Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
cec27dad
Commit
cec27dad
authored
Sep 29, 2023
by
mgoin
Browse files
Fix implementation
parent
101b2884
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
6 deletions
+30
-6
lm_eval/models/deepsparse.py
lm_eval/models/deepsparse.py
+30
-6
No files found.
lm_eval/models/deepsparse.py
View file @
cec27dad
import
torch
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
tqdm
import
tqdm
import
random
import
deepsparse
import
deepsparse
from
typing
import
Optional
,
Union
from
lm_eval
import
utils
from
lm_eval.base
import
BaseLM
from
lm_eval.base
import
BaseLM
class
DeepSparseLM
(
BaseLM
):
class
DeepSparseLM
(
BaseLM
):
_DEFAULT_MAX_LENGTH
=
2048
_DEFAULT_MAX_LENGTH
=
2048
def
__init__
(
def
__init__
(
...
@@ -23,7 +26,7 @@ class DeepSparseLM(BaseLM):
...
@@ -23,7 +26,7 @@ class DeepSparseLM(BaseLM):
self
.
model
=
deepsparse
.
Pipeline
.
create
(
self
.
model
=
deepsparse
.
Pipeline
.
create
(
task
=
"text-generation"
,
task
=
"text-generation"
,
model_path
=
pretrained
,
model_path
=
pretrained
,
sequence_length
=
max_length
or
_DEFAULT_MAX_LENGTH
,
sequence_length
=
max_length
or
self
.
_DEFAULT_MAX_LENGTH
,
trust_remote_code
=
trust_remote_code
,
trust_remote_code
=
trust_remote_code
,
batch_size
=
batch_size
,
batch_size
=
batch_size
,
)
)
...
@@ -36,8 +39,11 @@ class DeepSparseLM(BaseLM):
...
@@ -36,8 +39,11 @@ class DeepSparseLM(BaseLM):
self
.
_max_gen_toks
=
max_gen_toks
self
.
_max_gen_toks
=
max_gen_toks
@
property
@
property
def
eot_token_id
(
self
):
def
eot_token
(
self
)
->
str
:
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return
self
.
tokenizer
.
eos_token
@
property
def
eot_token_id
(
self
)
->
int
:
return
self
.
tokenizer
.
eos_token_id
return
self
.
tokenizer
.
eos_token_id
@
property
@
property
...
@@ -125,6 +131,8 @@ class DeepSparseLM(BaseLM):
...
@@ -125,6 +131,8 @@ class DeepSparseLM(BaseLM):
do_sample
=
False
,
do_sample
=
False
,
)
)
responses
=
responses
if
type
(
responses
)
is
list
else
[
responses
]
for
response
in
responses
:
for
response
in
responses
:
response
=
response
.
generations
[
0
].
text
response
=
response
.
generations
[
0
].
text
# Ensure the generated responses do not contain the stop sequences.
# Ensure the generated responses do not contain the stop sequences.
...
@@ -136,3 +144,19 @@ class DeepSparseLM(BaseLM):
...
@@ -136,3 +144,19 @@ class DeepSparseLM(BaseLM):
return
reorder
.
get_original
(
results
)
return
reorder
.
get_original
(
results
)
def
loglikelihood
(
self
,
requests
):
raise
NotImplementedError
()
def
loglikelihood_rolling
(
self
,
requests
):
raise
NotImplementedError
()
def
_loglikelihood_tokens
(
self
,
requests
,
disable_tqdm
=
False
):
raise
NotImplementedError
(
"No support for logits."
)
def
_model_call
(
self
,
inps
):
# Isn't used because we override _loglikelihood_tokens
raise
NotImplementedError
()
def
_model_generate
(
self
,
context
,
max_length
,
eos_token_id
):
# Isn't used because we override greedy_until
raise
NotImplementedError
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment