Unverified commit 2d11f2e5, authored by Baber Abbasi, committed by GitHub

[API] left truncate for generate_until (#2554)

* left truncate for generate_until

* pre-commit
parent bcb4cbf4
@@ -209,7 +209,7 @@ Not supported yet: multi-node evaluation and combinations of data replication wi
Pipeline parallelism during evaluation is supported with OpenVINO models.
To enable pipeline parallelism, set `pipeline_parallel` in `model_args`. In addition, you also have to set `device` to the value `HETERO:<GPU index1>,<GPU index2>`, for example `HETERO:GPU.1,GPU.0`. For example, the command to use pipeline parallelism of 2 is:
```
lm_eval --model openvino \
...
```
@@ -448,9 +448,13 @@ class TemplateAPI(TemplateLM):
for chunk in chunks:
    for cache_key, context_enc, continuation_enc in chunk:
        # max_length - 1 as we always have 1 token for generation
        inp = (context_enc + continuation_enc)[-self.max_length :]
        if len(inp) < len(context_enc + continuation_enc):
            eval_logger.warning(
                f"Context length ({len(context_enc)}) + continuation length ({len(continuation_enc)}) > max_length ({self.max_length}). Left truncating context."
            )
        ctxlen = len(context_enc) - max(
            0, len(context_enc) + len(continuation_enc) - self.max_length
        )
        inputs.append(inp)
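For readers skimming the diff, here is a small standalone sketch of the left-truncation and `ctxlen` arithmetic introduced above; the token IDs and `max_length` value are invented for illustration and are not part of the change.

```
# Illustrative only: token IDs and max_length are made-up values.
context_enc = [11, 12, 13, 14, 15, 16]   # 6 context tokens
continuation_enc = [21, 22, 23]          # 3 continuation tokens
max_length = 7                           # assumed model context window

# Keep only the rightmost max_length tokens of context + continuation.
inp = (context_enc + continuation_enc)[-max_length:]
assert inp == [13, 14, 15, 16, 21, 22, 23]

# ctxlen shrinks by exactly the number of tokens cut from the left,
# so the continuation is still scored against the surviving context.
ctxlen = len(context_enc) - max(
    0, len(context_enc) + len(continuation_enc) - max_length
)
assert ctxlen == 4  # 6 - (6 + 3 - 7)
```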
@@ -594,6 +598,24 @@ class TemplateAPI(TemplateLM):
pbar = tqdm(desc="Requesting API", total=len(requests))
for chunk in chunked:
    contexts, all_gen_kwargs, encodings_list = zip(*chunk)
    if self.tokenized_requests:
        max_gen_toks = all_gen_kwargs[0].get(
            "max_gen_toks", self._max_gen_toks
        )
        max_context_len = self.max_length - max_gen_toks
        encodings_list = [x[-max_context_len:] for x in encodings_list]
        if any(
            len(x) + max_gen_toks > self.max_length for x in encodings_list
        ):
            eval_logger.warning(
                f"Some contexts exceeded max length ({self.max_length}) - max_gen_toks ({max_gen_toks}). They were left truncated."
            )
    else:
        eval_logger.info(
            "Tokenized requests are disabled. Context + generation length is not checked."
        )
    req = encodings_list if self.tokenized_requests else contexts
    outputs = retry(
        stop=stop_after_attempt(self.max_retries),
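The truncation rule for generation requests can likewise be checked in isolation. A minimal sketch, assuming made-up values for `max_length` and `max_gen_toks` that stand in for the attributes referenced above:

```
# Illustrative only: the numbers are assumptions, not values from the diff.
max_length = 10      # assumed model context window
max_gen_toks = 4     # tokens reserved for generation
max_context_len = max_length - max_gen_toks  # 6 tokens of context may remain

encodings_list = [
    list(range(5)),   # short enough, kept as-is
    list(range(9)),   # too long, left-truncated to its last 6 tokens
]
encodings_list = [x[-max_context_len:] for x in encodings_list]

assert encodings_list[0] == [0, 1, 2, 3, 4]
assert encodings_list[1] == [3, 4, 5, 6, 7, 8]
# Every prompt now leaves max_gen_toks tokens of headroom.
assert all(len(x) + max_gen_toks <= max_length for x in encodings_list)
```

The asynchronous branch in the next hunk applies the same truncation before dispatching its batched requests.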
@@ -625,6 +647,24 @@ class TemplateAPI(TemplateLM):
else:
    for chunk in chunked:
        contexts, all_gen_kwargs, encodings_list = zip(*chunk)
        if self.tokenized_requests:
            max_gen_toks = all_gen_kwargs[0].get(
                "max_gen_toks", self._max_gen_toks
            )
            max_context_len = self.max_length - max_gen_toks
            encodings_list = [x[-max_context_len:] for x in encodings_list]
            if any(
                len(x) + max_gen_toks > self.max_length for x in encodings_list
            ):
                eval_logger.warning(
                    f"Some contexts exceeded max length ({self.max_length}) - max_gen_toks ({max_gen_toks}). They were left truncated."
                )
        else:
            eval_logger.info(
                "Tokenized requests are disabled. Context + generation length is not checked."
            )
        req = encodings_list if self.tokenized_requests else contexts
        results = itertools.chain.from_iterable(
            asyncio.run(
......
@@ -71,9 +71,11 @@ class OptimumLM(HFLM):
else:
    model_kwargs["ov_config"] = {}
model_kwargs["ov_config"].setdefault("CACHE_DIR", "")
if "pipeline_parallel" in model_kwargs:
    if model_kwargs["pipeline_parallel"]:
        model_kwargs["ov_config"]["MODEL_DISTRIBUTION_POLICY"] = (
            "PIPELINE_PARALLEL"
        )
model_file = Path(pretrained) / "openvino_model.xml"
if model_file.exists():
    export = False
......
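As a sanity check on the `ov_config` handling above, here is a minimal sketch of the dictionary that results when `pipeline_parallel` is requested; the starting `model_kwargs` contents are assumed for illustration.

```
# Illustrative only: the input model_kwargs dict is an assumption.
model_kwargs = {"pipeline_parallel": True}

model_kwargs.setdefault("ov_config", {})
model_kwargs["ov_config"].setdefault("CACHE_DIR", "")
if model_kwargs.get("pipeline_parallel"):
    model_kwargs["ov_config"]["MODEL_DISTRIBUTION_POLICY"] = "PIPELINE_PARALLEL"

# Resulting configuration passed on to the OpenVINO model:
# {"pipeline_parallel": True,
#  "ov_config": {"CACHE_DIR": "", "MODEL_DISTRIBUTION_POLICY": "PIPELINE_PARALLEL"}}
```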