Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
4ef930c1
Unverified
Commit
4ef930c1
authored
Aug 14, 2020
by
mcarilli
Committed by
GitHub
Aug 14, 2020
Browse files
Should pass stricter stride/size checks in pytorch (#942)
parent
5d9b5cbc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
apex/parallel/optimized_sync_batchnorm_kernel.py
apex/parallel/optimized_sync_batchnorm_kernel.py
+4
-4
No files found.
apex/parallel/optimized_sync_batchnorm_kernel.py
View file @
4ef930c1
...
...
@@ -33,11 +33,11 @@ class SyncBatchnormFunction(Function):
mean_all
=
torch
.
empty
(
world_size
,
mean
.
size
(
0
),
dtype
=
mean
.
dtype
,
device
=
device
)
var_all
=
torch
.
empty
(
world_size
,
var_biased
.
size
(
0
),
dtype
=
var_biased
.
dtype
,
device
=
device
)
count_all
=
torch
.
cuda
.
IntTensor
(
world_size
,
device
=
device
)
mean_l
=
[
mean_all
.
narrow
(
0
,
i
,
1
)
for
i
in
range
(
world_size
)]
var_l
=
[
var_all
.
narrow
(
0
,
i
,
1
)
for
i
in
range
(
world_size
)]
mean_l
=
[
mean_all
.
narrow
(
0
,
i
,
1
)
.
view
(
-
1
)
for
i
in
range
(
world_size
)]
var_l
=
[
var_all
.
narrow
(
0
,
i
,
1
)
.
view
(
-
1
)
for
i
in
range
(
world_size
)]
count_l
=
[
count_all
.
narrow
(
0
,
i
,
1
)
for
i
in
range
(
world_size
)]
torch
.
distributed
.
all_gather
(
mean_l
,
mean
,
process_group
)
torch
.
distributed
.
all_gather
(
var_l
,
var_biased
,
process_group
)
torch
.
distributed
.
all_gather
(
mean_l
,
mean
.
view
(
-
1
)
,
process_group
)
torch
.
distributed
.
all_gather
(
var_l
,
var_biased
.
view
(
-
1
)
,
process_group
)
torch
.
distributed
.
all_gather
(
count_l
,
torch
.
cuda
.
IntTensor
([
count
],
device
=
device
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment