Commit f72a48dd authored by Boris Bonev's avatar Boris Bonev
Browse files

fixing more issues from merge

parent 828f2ba3
......@@ -76,7 +76,7 @@ def _split_distributed_convolution_tensor_s2(
):
"""
Splits a pre-computed convolution tensor along the latitude dimension for distributed processing.
This function takes a convolution tensor that was generated by the serial routine and filters
it to only include entries corresponding to the local latitude slice assigned to this process.
The filtering is done based on the polar group rank and the computed split shapes.
......@@ -100,11 +100,6 @@ def _split_distributed_convolution_tensor_s2(
Filtered values corresponding to the local latitude slice
"""
assert len(in_shape) == 2
assert len(out_shape) == 2
kernel_size = filter_basis.kernel_size
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
......@@ -154,7 +149,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
groups: Optional[int]
Number of groups
grid_in: Optional[str]
Grid type for the input tensor
Grid type for the input tensor
grid_out: Optional[str]
Grid type for the output tensor
bias: Optional[bool]
......@@ -327,7 +322,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
groups: Optional[int]
Number of groups
grid_in: Optional[str]
Grid type for the input tensor
Grid type for the input tensor
grid_out: Optional[str]
Grid type for the output tensor
bias: Optional[bool]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment