Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
1e5f7a2f
Commit
1e5f7a2f
authored
Dec 16, 2023
by
Boris Bonev
Browse files
adjusting initialization
parent
9577cc8f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
2 deletions
+4
-2
torch_harmonics/s2_convolutions.py
torch_harmonics/s2_convolutions.py
+4
-2
No files found.
torch_harmonics/s2_convolutions.py
View file @
1e5f7a2f
...
@@ -208,7 +208,8 @@ class DiscreteContinuousConvS2(nn.Module):
...
@@ -208,7 +208,8 @@ class DiscreteContinuousConvS2(nn.Module):
if
out_channels
%
self
.
groups
!=
0
:
if
out_channels
%
self
.
groups
!=
0
:
raise
ValueError
(
"Error, the number of output channels has to be an integer multiple of the group size"
)
raise
ValueError
(
"Error, the number of output channels has to be an integer multiple of the group size"
)
self
.
groupsize
=
in_channels
//
self
.
groups
self
.
groupsize
=
in_channels
//
self
.
groups
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
out_channels
,
self
.
groupsize
,
kernel_shape
[
0
]))
scale
=
math
.
sqrt
(
1.0
/
self
.
groupsize
)
self
.
weight
=
nn
.
Parameter
(
scale
*
torch
.
randn
(
out_channels
,
self
.
groupsize
,
kernel_shape
[
0
]))
if
bias
:
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
out_channels
))
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
out_channels
))
...
@@ -299,7 +300,8 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
...
@@ -299,7 +300,8 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
if
out_channels
%
self
.
groups
!=
0
:
if
out_channels
%
self
.
groups
!=
0
:
raise
ValueError
(
"Error, the number of output channels has to be an integer multiple of the group size"
)
raise
ValueError
(
"Error, the number of output channels has to be an integer multiple of the group size"
)
self
.
groupsize
=
in_channels
//
self
.
groups
self
.
groupsize
=
in_channels
//
self
.
groups
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
out_channels
,
self
.
groupsize
,
kernel_shape
[
0
]))
scale
=
math
.
sqrt
(
1.0
/
self
.
groupsize
)
self
.
weight
=
nn
.
Parameter
(
scale
*
torch
.
randn
(
out_channels
,
self
.
groupsize
,
kernel_shape
[
0
]))
if
bias
:
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
out_channels
))
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
out_channels
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment