"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "61594507ed4f68b8eaa09ede1c6c0e7081c62b39"
Commit 8db0a470 authored by lintangsutawika's avatar lintangsutawika
Browse files

update

parent 1b9deaab
import copy import copy
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch
import transformers import transformers
from tqdm import tqdm from tqdm import tqdm
...@@ -212,7 +213,15 @@ class HFMultimodalLM(HFLM): ...@@ -212,7 +213,15 @@ class HFMultimodalLM(HFLM):
### Up to here: was identical to non-multimodal HFLM generate_until ### ### Up to here: was identical to non-multimodal HFLM generate_until ###
for chunk in chunks: for idx, _chunk in enumerate(chunks):
if idx == 0:
zero_chunk = _chunk
chunk = _chunk
elif idx == 69:
chunk = zero_chunk
else:
chunk = _chunk
chunk = _chunk
contexts, all_gen_kwargs, aux_arguments = zip( contexts, all_gen_kwargs, aux_arguments = zip(
*chunk *chunk
) # TODO: can we cut down further on number of distinct things we pass around? ) # TODO: can we cut down further on number of distinct things we pass around?
...@@ -264,7 +273,7 @@ class HFMultimodalLM(HFLM): ...@@ -264,7 +273,7 @@ class HFMultimodalLM(HFLM):
self.device, self.model.dtype self.device, self.model.dtype
) # TODO: factor out into a tok_batch_encode bit ; truncate from left using max_ctx_len ) # TODO: factor out into a tok_batch_encode bit ; truncate from left using max_ctx_len
print(inputs) # print(inputs)
context_enc = inputs["input_ids"] context_enc = inputs["input_ids"]
...@@ -272,7 +281,8 @@ class HFMultimodalLM(HFLM): ...@@ -272,7 +281,8 @@ class HFMultimodalLM(HFLM):
kwargs["max_length"] = context_enc.shape[1] + max_gen_toks kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
cont = self._model_generate(inputs, stop=until, **gen_kwargs) cont = self._model_generate(inputs, stop=until, **gen_kwargs)
# del inputs
# del _chunk
### essentially same as HFLM beyond this line! ### essentially same as HFLM beyond this line!
cont_toks_list = cont.tolist() cont_toks_list = cont.tolist()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment