Unverified Commit 02a5274b authored by Chaitanya Sri Krishna Lolla's avatar Chaitanya Sri Krishna Lolla Committed by GitHub
Browse files

Enable support for sparse tensors for multi_tensor_apply (#6)

parent 2d0f9cf2
...@@ -56,7 +56,7 @@ void multi_tensor_apply( ...@@ -56,7 +56,7 @@ void multi_tensor_apply(
for(int t = 0; t < tensor_lists[l].size(); t++) for(int t = 0; t < tensor_lists[l].size(); t++)
{ {
// TODO: Print which tensor fails. // TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous(); bool contiguous_memory = (tensor_lists[l][t].is_sparse()) ? tensor_lists[l][t]._values().is_contiguous() : tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5 #ifdef VERSION_GE_1_5
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast)); contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif #endif
...@@ -78,8 +78,15 @@ void multi_tensor_apply( ...@@ -78,8 +78,15 @@ void multi_tensor_apply(
for(int t = 0; t < ntensors; t++) for(int t = 0; t < ntensors; t++)
{ {
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for(int d = 0; d < depth; d++) for(int d = 0; d < depth; d++) {
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); if (tensor_lists[d][t].is_sparse()) {
at::Tensor dst = at::zeros(tensor_lists[d][t].sizes(), tensor_lists[d][t].options().layout(at::kStrided));
dst.add_(tensor_lists[d][t]);
tl.addresses[d][loc_tensor_info] = dst.data_ptr();
} else {
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
}
}
loc_tensor_info++; loc_tensor_info++;
int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size; int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment