Commit 8db0a470 authored by lintangsutawika

update

parent 1b9deaab
+import copy
 from typing import List, Optional, Tuple, Union
 import torch
 import transformers
 from tqdm import tqdm
@@ -212,7 +213,15 @@ class HFMultimodalLM(HFLM):
         ### Up to here: was identical to non-multimodal HFLM generate_until ###
-        for chunk in chunks:
+        for idx, _chunk in enumerate(chunks):
+            if idx == 0:
+                zero_chunk = _chunk
+                chunk = _chunk
+            elif idx == 69:
+                chunk = zero_chunk
+            else:
+                chunk = _chunk
+            chunk = _chunk
             contexts, all_gen_kwargs, aux_arguments = zip(
                 *chunk
             )  # TODO: can we cut down further on number of distinct things we pass around?
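Reading of this hunk: the loop now caches the first chunk as `zero_chunk` and swaps it back in at iteration 69, which looks like a temporary probe for nondeterminism or memory growth (rerunning a known-identical batch late in the pass should reproduce its early result). As committed, though, the trailing `chunk = _chunk` after the conditional overwrites the replay, so the `idx == 69` branch appears to have no effect. Below is a minimal, self-contained sketch of the replay idea, with a hypothetical `process` standing in for the per-chunk generation call:

def process(batch):
    # stand-in for the real tokenize-and-generate call on one chunk
    return [x * 2 for x in batch]

batches = [[i, i + 1] for i in range(100)]
first_batch = None
first_out = None
for idx, batch in enumerate(batches):
    if idx == 0:
        first_batch = batch
        first_out = process(first_batch)
    elif idx == 69:
        # replay the cached first batch: a deterministic pipeline must
        # reproduce exactly what it produced at idx == 0
        assert process(first_batch) == first_out
    else:
        process(batch)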
@@ -264,7 +273,7 @@ class HFMultimodalLM(HFLM):
                 self.device, self.model.dtype
             )  # TODO: factor out into a tok_batch_encode bit ; truncate from left using max_ctx_len
-            print(inputs)
+            # print(inputs)
             context_enc = inputs["input_ids"]
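The silenced `print(inputs)` was dumping the processor's batch encoding after it had been moved onto the model's device and dtype. A rough sketch of that encode-and-move step, assuming a `transformers` `AutoProcessor`; the checkpoint, prompt, and image below are placeholders, and the harness's real code also left-truncates to the max context length:

import torch
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # placeholder checkpoint
image = Image.new("RGB", (224, 224))  # dummy image standing in for real task inputs
inputs = processor(
    text=["USER: <image>\nDescribe the image. ASSISTANT:"],
    images=[image],
    return_tensors="pt",
)
device = "cuda" if torch.cuda.is_available() else "cpu"
inputs = inputs.to(device, torch.float16)  # mirrors .to(self.device, self.model.dtype)
print({k: (tuple(v.shape), v.dtype) for k, v in inputs.items()})  # what print(inputs) exposed, in brief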
@@ -272,7 +281,8 @@ class HFMultimodalLM(HFLM):
             kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
             cont = self._model_generate(inputs, stop=until, **gen_kwargs)
             # del inputs
+            # del _chunk
             ### essentially same as HFLM beyond this line!
             cont_toks_list = cont.tolist()
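The commented-out `del inputs` / `del _chunk` suggest this commit was chasing memory growth across generation chunks. A small sketch of the usual cleanup pattern: keep only the CPU-side results, drop the tensor references, and optionally release CUDA's cached blocks. `_model_generate_stub` is a hypothetical stand-in for the real generation call.

import torch

def _model_generate_stub(inputs):
    # stand-in for self._model_generate(inputs, stop=until, **gen_kwargs)
    return inputs["input_ids"]

def run_chunk(inputs):
    cont = _model_generate_stub(inputs)
    cont_toks_list = cont.tolist()  # keep only the CPU-side token lists
    del cont, inputs                # drop the last references to the (potentially GPU) tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()    # hand freed blocks back so usage drops in nvidia-smi
    return cont_toks_list

print(run_chunk({"input_ids": torch.tensor([[1, 2, 3]])}))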