gaoqiong / lm-evaluation-harness

Commit 101b2884 (unverified)
DeepSparseLM Implementation

Authored Sep 29, 2023 by Michael Goin; committed by GitHub on Sep 29, 2023
Parent: 2c18e367
Changes: 2 changed files with 140 additions and 0 deletions
  lm_eval/models/__init__.py    +2   -0
  lm_eval/models/deepsparse.py  +138 -0
lm_eval/models/__init__.py

@@ -3,6 +3,7 @@ from . import gpt3
 from . import anthropic_llms
 from . import huggingface
 from . import textsynth
+from . import deepsparse
 from . import dummy

 MODEL_REGISTRY = {
@@ -14,6 +15,7 @@ MODEL_REGISTRY = {
     "gpt3": gpt3.GPT3LM,
     "anthropic": anthropic_llms.AnthropicLM,
     "textsynth": textsynth.TextSynthLM,
+    "deepsparse": deepsparse.DeepSparseLM,
     "dummy": dummy.DummyLM,
 }
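With the registry entry above in place, the new backend can be resolved by name. A minimal sketch of that lookup (assuming the harness's existing get_model helper in lm_eval/models/__init__.py, which is not touched by this commit):

    from lm_eval import models

    # Look up the "deepsparse" key added in this commit; get_model returns the class.
    lm_cls = models.get_model("deepsparse")
    print(lm_cls.__name__)  # DeepSparseLM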
lm_eval/models/deepsparse.py (new file, mode 100644)
import torch
import deepsparse
from tqdm import tqdm
from typing import List, Optional, Tuple, Union

from lm_eval import utils
from lm_eval.base import BaseLM


class DeepSparseLM(BaseLM):
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained: str,
        tokenizer: Optional[str] = None,
        batch_size: Optional[Union[int, str]] = 1,
        max_gen_toks: Optional[int] = 256,
        max_length: Optional[int] = None,
        trust_remote_code: Optional[bool] = False,
    ):
        super().__init__()

        # Initialize new model and tokenizer instances
        self.model = deepsparse.Pipeline.create(
            task="text-generation",
            model_path=pretrained,
            sequence_length=max_length or self._DEFAULT_MAX_LENGTH,
            trust_remote_code=trust_remote_code,
            batch_size=batch_size,
        )
        self.tokenizer = tokenizer if tokenizer else self.model.tokenizer

        self.vocab_size = self.tokenizer.vocab_size

        self._batch_size = int(batch_size)
        self._max_length = max_length
        self._max_gen_toks = max_gen_toks

    @property
    def eot_token(self):
        # String form of the end-of-text token, used as a stop sequence in greedy_until
        return self.tokenizer.eos_token

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        # seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        # for attr in seqlen_config_attrs:
        #     if hasattr(self.model.config, attr):
        #         return getattr(self.model.config, attr)
        # if hasattr(self.tokenizer, "model_max_length"):
        #     if self.tokenizer.model_max_length == 1000000000000000019884624838656:
        #         return self._DEFAULT_MAX_LENGTH
        #     return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self):
        return self._max_gen_toks

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def device(self):
        # DeepSparse runs inference on CPU
        return "cpu"

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def greedy_until(
        self, requests: List[Tuple[str, Union[List[str], str]]]
    ) -> List[str]:
        def _collate(x):
            tokens = self.tok_encode(x[0])
            return len(tokens), x[0]

        results = []
        reorder = utils.Reorderer(requests, _collate)

        # adaptive_batch_size = None
        # if self.batch_size == "auto":
        #     # using rolling window with maximum context
        #     print("Passed argument batch_size = auto. Detecting largest batch size")
        #     batch_size = self._detect_batch_size()
        #     print(f"Determined Largest batch size: {batch_size}")
        #     adaptive_batch_size = batch_size

        for chunk in utils.chunks(
            tqdm(reorder.get_reordered(), disable=False),
            self.batch_size,
        ):
            context = [c[0] for c in chunk]
            request_args = chunk[0][1]
            stop = request_args.get("until", None)
            stop_sequences = stop if isinstance(stop, list) else [stop]
            max_generation_length = request_args.get("max_length", None)

            assert (
                isinstance(max_generation_length, int)
                or max_generation_length is None
            )
            assert isinstance(stop_sequences, list) or stop_sequences is None

            # TODO: Find a better way to handle stop sequences for 0-shot.
            if stop_sequences is None:
                until = [self.eot_token]
            else:
                until = stop_sequences + [self.eot_token]

            if max_generation_length is None:
                max_tokens = self.max_gen_toks
            else:
                max_tokens = max_generation_length

            # token_context = self.tok_encode_batch(context)

            responses = self.model(
                sequences=context,
                max_new_tokens=max_tokens,
                stop=until,
                do_sample=False,
            )

            for response in responses:
                response = response.generations[0].text
                # Ensure the generated responses do not contain the stop sequences.
                for term in until:
                    response = response.split(term)[0]
                # partial caching
                self.cache_hook.add_partial(
                    "greedy_until", (context, until), response
                )
                results.append(response)

        return reorder.get_original(results)
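For context, a rough end-to-end usage sketch of the new backend. It assumes the harness's evaluator.simple_evaluate entry point and uses a placeholder model path (neither is introduced by this commit); the path must point to a DeepSparse-compatible text-generation model stub or local deployment directory.

    from lm_eval import evaluator

    # "pretrained=..." below is a placeholder; substitute a real SparseZoo stub
    # or local model directory that deepsparse.Pipeline can load.
    results = evaluator.simple_evaluate(
        model="deepsparse",
        model_args="pretrained=/path/to/deepsparse-text-generation-model",
        tasks=["lambada_openai"],
        batch_size=1,
    )
    print(results["results"])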