Project: gaoqiong / lm-evaluation-harness

Commit ff1b649e ("cleanup")
Authored Jun 28, 2024 by Nathan Habib
Parent: 3c390c43
Showing 2 changed files with 25 additions and 32 deletions:

  lm_eval/models/huggingface.py  +24 -31
  lm_eval/models/utils.py        +1 -1
lm_eval/models/huggingface.py
@@ -111,11 +111,13 @@ class HFLM(TemplateLM):
         gpus = torch.cuda.device_count()
         accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
         accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
+        if accelerator.num_processes > 1:
             self.accelerator = accelerator

         if "npu" in accelerator.device.type:
             gpus = torch.npu.device_count()

+        # using one process with no model parallelism
         if not (parallelize or accelerator.num_processes > 1):
             # use user-passed device
             device_list = set(
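As context for the hunk above (which restores the multi-process guard around `self.accelerator`), here is a minimal, self-contained sketch of the `accelerate` pattern the code relies on: build an `Accelerator` with an extended process-group timeout, then branch on `num_processes`. This illustrates the library API only, not the harness's full `__init__`; the `parallelize` flag and the NPU check are simplified stand-ins.

```python
# Minimal sketch (not the harness's __init__): create an Accelerator with a long
# process-group timeout and decide whether this is a multi-process run.
from datetime import timedelta

import torch
from accelerate import Accelerator, InitProcessGroupKwargs


def detect_distributed_setup(parallelize: bool = False):
    gpus = torch.cuda.device_count()

    # A 52-week timeout keeps slow ranks (e.g. long model downloads) from hitting
    # the default process-group timeout, mirroring the hunk above.
    accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
    accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])

    # On Ascend NPUs the device count comes from torch.npu rather than torch.cuda.
    if "npu" in accelerator.device.type and hasattr(torch, "npu"):
        gpus = torch.npu.device_count()

    # One process and no model parallelism means a user-passed device can be honored.
    single_process = not (parallelize or accelerator.num_processes > 1)
    return accelerator, gpus, single_process


if __name__ == "__main__":
    _, n_devices, single = detect_distributed_setup()
    print(f"devices={n_devices}, single_process={single}")
```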
@@ -513,12 +515,10 @@ class HFLM(TemplateLM):
         revision: str = "main",
         trust_remote_code: bool = False,
     ) -> None:
-        with self.accelerator.main_process_first():
         self._config = transformers.AutoConfig.from_pretrained(
             pretrained,
             revision=revision,
             trust_remote_code=trust_remote_code,
-            force_download=False,
         )

     def _create_model(
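The config-loading call kept by this hunk is plain `transformers.AutoConfig.from_pretrained`. Below is a hedged, standalone example of that API with the same keyword arguments; the model id is only an example and is not taken from the commit.

```python
# Sketch of the retained config-loading call; "gpt2" is just an example model id.
import transformers

config = transformers.AutoConfig.from_pretrained(
    "gpt2",                   # any Hub id or local path
    revision="main",          # branch, tag, or commit sha on the Hub
    trust_remote_code=False,  # enable only for repos whose custom code you trust
)
print(config.model_type)
```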
@@ -578,14 +578,11 @@ class HFLM(TemplateLM):
                     model_kwargs["bnb_4bit_compute_dtype"]
                 )
-            with self.accelerator.main_process_first():
-            #model_kwargs["device_map"] = "balanced_low_0"
             self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                 pretrained,
                 revision=revision,
                 torch_dtype=get_dtype(dtype),
                 trust_remote_code=trust_remote_code,
-                force_download=False,
                 **model_kwargs,
             )

         else:
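For reference, the surviving call here is a standard `AutoModelForCausalLM.from_pretrained` with an explicit `torch_dtype`. The sketch below illustrates that usage; `resolve_dtype` is a hypothetical stand-in for the harness's `get_dtype` helper and only shows the idea of mapping a string to a `torch.dtype`.

```python
# Hedged sketch of the from_pretrained call above; resolve_dtype is a stand-in
# for the harness's get_dtype helper, shown only to illustrate the idea.
from typing import Union

import torch
import transformers


def resolve_dtype(dtype: Union[str, torch.dtype, None]) -> Union[str, torch.dtype, None]:
    """Map a string such as "bfloat16" to torch.bfloat16; pass "auto"/None through."""
    if isinstance(dtype, str) and dtype not in ("auto", ""):
        return getattr(torch, dtype)
    return dtype


model = transformers.AutoModelForCausalLM.from_pretrained(
    "gpt2",                               # example model id
    revision="main",
    torch_dtype=resolve_dtype("float32"),
    trust_remote_code=False,
)
```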
@@ -679,7 +676,6 @@ class HFLM(TemplateLM):
                 revision=revision,
                 trust_remote_code=trust_remote_code,
                 use_fast=use_fast_tokenizer,
-                force_download=False
             )
         else:
             assert isinstance(
@@ -709,7 +705,7 @@ class HFLM(TemplateLM):
             )
             max_context_enc = len(context_enc[-(self.max_length + 1) :])
             max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
-            security_margin_factor = 6  # batch sizes for log prob evals sometimes generate OOMs
+            security_margin_factor = 4  # batch sizes for log prob evals sometimes generate OOMs
         elif len(requests[0]) == 2:  # generative evals
             # using rolling window with maximum context
             longest_context = max([len(self.tok_encode(request[0])) + request[1].get("max_gen_toks", self.max_length) for request in requests[pos:]])
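This hunk dials the safety margin for log-probability requests back from 6 to 4. The diff does not show how `security_margin_factor` is consumed downstream, so the snippet below is purely an assumed illustration of one plausible use: shrinking an auto-detected batch size by that factor to leave headroom against OOMs.

```python
# Assumed illustration only: the commit changes the margin value, but the diff
# does not show how it is applied. One plausible pattern is to divide the
# auto-detected batch size by the factor to keep OOM headroom.
def apply_security_margin(detected_batch_size: int, security_margin_factor: int = 4) -> int:
    return max(1, detected_batch_size // security_margin_factor)


assert apply_security_margin(64, security_margin_factor=4) == 16
assert apply_security_margin(3, security_margin_factor=4) == 1  # never go below one
```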
@@ -721,7 +717,7 @@ class HFLM(TemplateLM):
             max_length = longest_context
             max_context_enc = max_length
             max_cont_enc = max_length
-            security_margin_factor = 6
+            security_margin_factor = 4

         # if OOM, then halves batch_size and tries again
@@ -751,7 +747,6 @@ class HFLM(TemplateLM):
             return batch_size

         try:
-            print(f"finding batch size on process {self.accelerator.local_process_index}")
             batch_size = forward_batch()
         except RuntimeError as e:
             if "No executable batch size found" in str(e):
@@ -762,7 +757,6 @@ class HFLM(TemplateLM):
         if self.world_size > 1:
             # if multi-GPU, always take minimum over all selected batch sizes
             max_rnk_bs = torch.tensor([batch_size], device=self.device)
-            print(f"gathering on process {self.accelerator.local_process_index}")
             gathered = (
                 self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
             )
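These two hunks only drop debug `print` calls, but the surrounding logic is worth a sketch: the batch size is found by a probe that halves on CUDA OOM (the `accelerate` `find_executable_batch_size` pattern), and in multi-GPU runs each rank's result is gathered so the minimum can be used everywhere. The sketch below illustrates that pattern with a made-up workload; it is not the harness's `_detect_batch_size`.

```python
# Sketch of the pattern behind these hunks: probe for a batch size that does not
# OOM (halving on failure), then agree on the minimum across ranks. The dummy
# allocation stands in for a real forward pass.
import torch
from accelerate import Accelerator
from accelerate.utils import find_executable_batch_size


def detect_batch_size(accelerator: Accelerator, starting_batch_size: int = 64) -> int:
    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def forward_batch(batch_size: int) -> int:
        # Allocate something proportional to batch_size so an oversized batch
        # can genuinely trigger the OOM that makes the decorator halve and retry.
        _ = torch.zeros((batch_size, 1024), device=accelerator.device)
        return batch_size

    batch_size = forward_batch()

    if accelerator.num_processes > 1:
        # Same idea as the gather above: every rank reports its batch size and
        # all ranks adopt the minimum so batches stay in lockstep.
        per_rank = accelerator.gather(
            torch.tensor([batch_size], device=accelerator.device)
        )
        batch_size = int(per_rank.min().item())

    return batch_size
```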
@@ -1044,7 +1038,7 @@ class HFLM(TemplateLM):
             else None
         )

-        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn, accelerator=self.accelerator)
+        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
         pbar = tqdm(
             total=len(requests),
             disable=(disable_tqdm or (self.rank != 0)),
@@ -1064,8 +1058,6 @@ class HFLM(TemplateLM):
             # tensors, then we pack them together into a batch, call the model, and then pick it all apart
             # again because vectorizing is annoying

-            from pprint import pprint
-
             for _, context_enc, continuation_enc in chunk:
                 # sanity check
                 assert len(context_enc) > 0
@@ -1210,8 +1202,6 @@ class HFLM(TemplateLM):
     ) -> List[str]:
         res = []

-        self.accelerator.wait_for_everyone()
-
         def _collate(req: Tuple[str, dict]):
             """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
@@ -1235,7 +1225,7 @@ class HFLM(TemplateLM):
         )

         batch_fn = (
             self._batch_scheduler
-            if self.batch_size == "auto"  # and not adaptive_batch_size
+            if self.batch_size == "auto"
             else None
         )
@@ -1277,9 +1267,10 @@ class HFLM(TemplateLM):
                     until = [eos]
                 else:
                     until.append(eos)
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
                 if max_gen_toks > self.max_length:
+                    # some model have low max length limit
                     max_gen_toks = self.max_gen_toks
             else:
                 max_gen_toks = self.max_gen_toks
@@ -1287,6 +1278,8 @@ class HFLM(TemplateLM):
             # set the max length in tokens of inputs ("context_enc")
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                 # max len for inputs = max length, minus room to generate the max new tokens
+                # if the max new tokens is too large, halve it until it fits as we cannot change
+                # the max model length
                 max_ctx_len = self.max_length - max_gen_toks
                 while max_ctx_len <= 0:
                     max_gen_toks = max_gen_toks // 2
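The last two hunks of this file clamp the generation budget: the requested `max_gen_toks` is capped by the model's maximum length, and the budget is then halved until the prompt still fits inside the fixed context window. Below is a small standalone illustration of that arithmetic, not the harness code itself.

```python
# Standalone illustration of the clamping described in the comments above:
# cap the generation budget at the model limit, then halve it until the prompt
# still has room inside the fixed context window.
def fit_generation_budget(max_length: int, requested_gen_toks: int, default_gen_toks: int):
    # some models have a low max-length limit, so fall back to the default budget
    max_gen_toks = default_gen_toks if requested_gen_toks > max_length else requested_gen_toks

    # max input length = model max length minus room for the new tokens; if the
    # budget is too large, halve it, since the context window itself is fixed
    max_ctx_len = max_length - max_gen_toks
    while max_ctx_len <= 0:
        max_gen_toks = max_gen_toks // 2
        max_ctx_len = max_length - max_gen_toks
    return max_ctx_len, max_gen_toks


assert fit_generation_budget(2048, requested_gen_toks=256, default_gen_toks=256) == (1792, 256)
assert fit_generation_budget(1024, requested_gen_toks=4096, default_gen_toks=1024) == (512, 512)
```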
lm_eval/models/utils.py
@@ -389,7 +389,7 @@ class Collator:
             self._arr_with_indices, fn=self._group_fn, group_by="contexts"
         )

-    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None, reset_batch_fn: Optional[Callable] = None, accelerator=None) -> Iterator:
+    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None, reset_batch_fn: Optional[Callable] = None) -> Iterator:
         """
         Generates and yields batches from the reordered array. The method of grouping and batching
         depends on the parameter `group_by`.
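The only change in this file removes the now-unused `accelerator` parameter, matching the call-site change in `huggingface.py` above. For readers unfamiliar with the signature, here is a hedged sketch of what an `n`-sized batching generator with an optional per-position `batch_fn` looks like; it illustrates the interface only and is not the `Collator` implementation (in particular, the real `batch_fn` signature may differ).

```python
# Illustrative batching generator shaped like the get_batched signature above:
# yield fixed batches of n items, or let an optional batch_fn choose the size.
# This is a sketch of the interface, not the Collator implementation.
from typing import Callable, Iterator, List, Optional, Sequence


def get_batched(
    arr: Sequence,
    n: int = 1,
    batch_fn: Optional[Callable[[int, Sequence], int]] = None,
) -> Iterator[List]:
    i = 0
    while i < len(arr):
        size = batch_fn(i, arr) if batch_fn is not None else n
        yield list(arr[i : i + size])
        i += size


# Fixed batches of 3, then a trivial scheduler that always returns 2.
assert list(get_batched(list(range(7)), n=3)) == [[0, 1, 2], [3, 4, 5], [6]]
assert list(get_batched("abcde", batch_fn=lambda i, a: 2)) == [["a", "b"], ["c", "d"], ["e"]]
```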