# TODO: These transposes are quite inefficient, but Flash Attention requires the layout
# [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to avoid many of these transpose/reshape/view operations.
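For context, a minimal sketch of the layout conversion the TODO refers to. The tensor names and shapes are illustrative, not the actual KV-cache API: caches are commonly stored as [batch, num_heads, seq_len, head_dim], while Flash Attention expects [batch, seq_len, num_heads, head_dim], forcing a transpose plus a copy.

```python
import torch

# Hypothetical shapes for illustration only.
batch_size, num_heads, seq_len, head_dim = 2, 8, 16, 64

# KV cache stored in the common [batch, num_heads, seq_len, head_dim] layout.
k_cache = torch.randn(batch_size, num_heads, seq_len, head_dim)

# Flash Attention wants [batch, seq_len, num_heads, head_dim]; transpose(1, 2)
# swaps the head and sequence dims, and .contiguous() materializes the copy.
# That extra copy on every call is the inefficiency the TODO describes.
k_flash = k_cache.transpose(1, 2).contiguous()
assert k_flash.shape == (batch_size, seq_len, num_heads, head_dim)
```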
"""Create a DataProto from a dict of tensors. This assumes that
1. All the tensor in tensors have the same dim0
2. Only dim0 is the batch dim
...
...
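A hedged usage sketch of the contract the docstring describes, using plain torch tensors; the `DataProto.from_dict` call mentioned in the comment is only referenced here, since its exact signature is not shown in this excerpt:

```python
import torch

# Two tensors satisfying both assumptions: same dim0, and only dim0 is batched.
tensors = {
    "input_ids": torch.randint(0, 1000, (4, 128)),  # dim0 == 4 (batch)
    "rewards": torch.zeros(4),                      # dim0 == 4 (batch)
}

# Assumption 1: every tensor's dim0 agrees.
assert len({t.shape[0] for t in tensors.values()}) == 1

# Assumption 2: only dim0 is the batch dim, so trailing dims may differ freely
# (128 trailing columns vs. none here). A constructor such as
# DataProto.from_dict(tensors) can then index, slice, and split along dim0.
```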
@@ -293,13 +301,14 @@ class DataProto:
         else:
             current_batch = tensor.shape[:num_batch_dims]
             assert batch_size == current_batch, (
-                f"Not all the tensors in tensors have the same batch size with batch_dims={num_batch_dims}. Got {pivot_key} has {batch_size}, {key} has {current_batch}"
+                f"Not all the tensors in tensors have the same batch size with batch_dims={num_batch_dims}. "
+                f"Got {pivot_key} has {batch_size}, {key} has {current_batch}"