Commit f671f133 authored by WenmuZhou's avatar WenmuZhou
Browse files

move chunk to infer model

parent c703a589
......@@ -94,8 +94,11 @@ class VQATokenPad(object):
'input_ids', 'labels', 'token_type_ids', 'bbox',
'attention_mask'
]:
if self.infer_mode and key == 'labels':
continue
if self.infer_mode:
if key != 'labels':
length = min(len(data[key]), self.max_seq_len)
data[key] = np.array(data[key][:length], dtype='int64')
data[key] = data[key][:length]
else:
continue
data[key] = np.array(data[key], dtype='int64')
return data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment