Commit ba2407e2 authored by rohithkrn

Merge remote-tracking branch 'rocm_up/master' into apex_amp_bfp15

parents d283f97f 02a5274b
@@ -56,7 +56,7 @@ void multi_tensor_apply(
     for(int t = 0; t < tensor_lists[l].size(); t++)
     {
       // TODO: Print which tensor fails.
-      bool contiguous_memory = tensor_lists[l][t].is_contiguous();
+      bool contiguous_memory = (tensor_lists[l][t].is_sparse()) ? tensor_lists[l][t]._values().is_contiguous() : tensor_lists[l][t].is_contiguous();
 #ifdef VERSION_GE_1_5
       contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
 #endif
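
For context, a minimal standalone sketch of the branch introduced above (the helper name chunk_kernel_can_read is hypothetical, not apex code): sparse COO tensors generally do not implement is_contiguous(), so the contiguity check must branch on layout and inspect the dense _values() buffer that a kernel would actually read.

#include <ATen/ATen.h>

// Illustrative helper (hypothetical name, not part of the commit): branch on
// sparsity before asking about contiguity, since calling is_contiguous() on a
// sparse COO tensor is not supported.
bool chunk_kernel_can_read(const at::Tensor& t) {
  if (t.is_sparse()) {
    return t._values().is_contiguous();   // layout of the packed values buffer
  }
  return t.is_contiguous();               // ordinary strided tensor
}
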
@@ -78,8 +78,15 @@ void multi_tensor_apply(
       for(int t = 0; t < ntensors; t++)
       {
         tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
-        for(int d = 0; d < depth; d++)
-          tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+        for(int d = 0; d < depth; d++) {
+          if (tensor_lists[d][t].is_sparse()) {
+            at::Tensor dst = at::zeros(tensor_lists[d][t].sizes(), tensor_lists[d][t].options().layout(at::kStrided));
+            dst.add_(tensor_lists[d][t]);
+            tl.addresses[d][loc_tensor_info] = dst.data_ptr();
+          } else {
+            tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+          }
+        }
         loc_tensor_info++;
         int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
......
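
The added branch materializes sparse gradients into dense storage before handing flat pointers to the chunked kernel. A minimal sketch of that densification pattern as a standalone ATen helper (the function densify is an illustrative assumption, not apex API):

#include <ATen/ATen.h>

// Illustrative helper (hypothetical name, not apex code): materialize a sparse
// tensor into a freshly allocated strided tensor so a flat data_ptr() can be
// passed to a kernel that expects dense memory.
at::Tensor densify(const at::Tensor& t) {
  if (!t.is_sparse())
    return t;
  // Zero-filled dense tensor with the same shape/dtype/device, forcing a
  // strided layout, then accumulate the sparse entries into it.
  at::Tensor dst = at::zeros(t.sizes(), t.options().layout(at::kStrided));
  dst.add_(t);
  return dst;
}

One design point worth noting: the raw pointer stored in tl.addresses is only valid while the densified tensor is kept alive, so the buffer must outlive the kernel launch that reads it; returning the dense tensor by value, as in the sketch, is one way to hand that ownership to the caller.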