# TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim].
# We would need to refactor the KV cache to avoid many of these transpose/reshape/view operations.
"""Create a DataProto from a dict of tensors. This assumes that
"""Create a DataProto from a dict of tensors. This assumes that
1. All the tensor in tensors have the same dim0
1. All the tensor in tensors have the same dim0
2. Only dim0 is the batch dim
2. Only dim0 is the batch dim
...
@@ -293,13 +301,14 @@ class DataProto:
else:
    current_batch = tensor.shape[:num_batch_dims]
    assert batch_size == current_batch, (
        f"Not all the tensors in tensors have the same batch size with batch_dims={num_batch_dims}. "
        f"Got {pivot_key} has {batch_size}, {key} has {current_batch}"
    )
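The assertion above validates that every tensor in the dict shares the same leading batch dims before a DataProto is built. A minimal standalone sketch of that check (hypothetical helper; the real logic lives inside DataProto's `from_dict`, and `pivot_key` here simply denotes the first tensor seen, which fixes the expected batch shape):

```python
from types import SimpleNamespace


def check_batch_dims(tensors, num_batch_dims=1):
    """Verify every tensor shares the same leading batch dims.

    Returns the common batch shape; raises AssertionError on mismatch.
    """
    pivot_key, batch_size = None, None
    for key, tensor in tensors.items():
        if pivot_key is None:
            # The first tensor fixes the expected batch shape.
            pivot_key = key
            batch_size = tensor.shape[:num_batch_dims]
        else:
            current_batch = tensor.shape[:num_batch_dims]
            assert batch_size == current_batch, (
                f"Not all the tensors in tensors have the same batch size "
                f"with batch_dims={num_batch_dims}. "
                f"Got {pivot_key} has {batch_size}, {key} has {current_batch}"
            )
    return batch_size


# Stand-in for a real tensor: anything exposing a .shape tuple works.
t = lambda *shape: SimpleNamespace(shape=tuple(shape))

tensors = {"input_ids": t(4, 128), "attention_mask": t(4, 128)}
check_batch_dims(tensors)  # returns (4,)
```

With `num_batch_dims=1` only dim0 must agree, matching assumption 2 in the docstring; passing a larger value tightens the check to the first N dims.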