OpenDAS / apex · Commits

Commit 3dd36070 (unverified), authored Jul 20, 2020 by Asit, committed by GitHub on Jul 20, 2020

Merge pull request #1 from NVIDIA/master

Updating my repo

Parents: 02a33875, 3104fd59
Showing 14 changed files with 595 additions and 99 deletions (+595 / -99).
  apex/amp/compat.py                                                        +4   -0
  apex/amp/lists/tensor_overrides.py                                        +6   -6
  apex/contrib/multihead_attn/encdec_multihead_attn.py                      +7   -1
  apex/contrib/multihead_attn/self_multihead_attn.py                        +7   -1
  apex/contrib/sparsity/asp.py                                              +3   -3
  apex/contrib/sparsity/sparse_masklib.py                                   +19  -5
  apex/parallel/optimized_sync_batchnorm_kernel.py                          +21  -15
  csrc/syncbn.cpp                                                           +11  -9
  csrc/welford.cu                                                           +81  -51
  tests/L0/run_optimizers/test_lamb.py                                      +259 -0
  tests/distributed/synced_batchnorm/single_gpu_unit_test.py                +7   -4
  tests/distributed/synced_batchnorm/two_gpu_test_different_batch_size.py   +158 -0
  tests/distributed/synced_batchnorm/two_gpu_unit_test.py                   +11  -4
  tests/distributed/synced_batchnorm/unit_test.sh                           +1   -0
apex/amp/compat.py

@@ -40,3 +40,7 @@ def scalar_python_val(x):
         return x.data[0]
     else:
         return x[0]
+
+# Accounts for the possibility that some ops may be removed from a namespace.
+def filter_attrs(module, attrs):
+    return list(attrname for attrname in attrs if hasattr(module, attrname))
apex/amp/lists/tensor_overrides.py

@@ -11,20 +11,20 @@ MODULE = torch.Tensor
 # MODULE = torch.autograd.Variable

-FP16_FUNCS = [
+FP16_FUNCS = compat.filter_attrs(MODULE, [
     '__matmul__',
-]
+])

-FP32_FUNCS = [
+FP32_FUNCS = compat.filter_attrs(MODULE, [
     '__ipow__',
     '__pow__',
     '__rpow__',

     # Cast to fp32 before transfer to CPU
     'cpu',
-]
+])

-CASTS = [
+CASTS = compat.filter_attrs(MODULE, [
     '__add__',
     '__div__',
     '__eq__',
@@ -46,7 +46,7 @@ CASTS = [
     '__rtruediv__',
     '__sub__',
     '__truediv__',
-]
+])

 # None of these, but here to make code cleaner.
 SEQUENCE_CASTS = []
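To see what the new guard buys in practice, here is a small sketch (illustrative only; it assumes an Apex checkout where apex.amp.compat exposes the filter_attrs helper added above, and uses '__div__' as an example of a method that later PyTorch releases removed from torch.Tensor):

    import torch
    from apex.amp import compat

    # Only the attribute names that actually exist on torch.Tensor survive,
    # so building the cast lists no longer fails when an op has been removed.
    wanted = ['__add__', '__matmul__', '__div__']
    present = compat.filter_attrs(torch.Tensor, wanted)
    assert set(present) <= set(wanted)
    print(present)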
apex/contrib/multihead_attn/encdec_multihead_attn.py

+import math
 import torch
 from torch import nn
 from torch.nn import Parameter

@@ -76,7 +78,11 @@ class EncdecMultiheadAttn(nn.Module):
     def reset_parameters(self):
         nn.init.xavier_uniform_(self.in_proj_weight_q)
-        nn.init.xavier_uniform_(self.in_proj_weight_kv)
+        # in_proj_weight_kv has shape [2 * hidden, hidden] but it should be
+        # initialized like a [hidden, hidden] matrix.
+        # sqrt(6 / (hidden + hidden)) / sqrt(6 / (2 * hidden + hidden)) = sqrt(1.5)
+        # therefore xavier_uniform gain should be set to sqrt(1.5).
+        nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5))
         nn.init.xavier_uniform_(self.out_proj_weight)
         if self.bias:
             nn.init.constant_(self.in_proj_bias_q, 0.)
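The gain values in the comments above can be sanity-checked against the standard Xavier-uniform bound, gain * sqrt(6 / (fan_in + fan_out)). This is only a sketch of that arithmetic with a hypothetical hidden size, not code from the commit:

    import math

    def xavier_bound(fan_in, fan_out, gain=1.0):
        # bound of the uniform distribution used by nn.init.xavier_uniform_
        return gain * math.sqrt(6.0 / (fan_in + fan_out))

    hidden = 1024  # hypothetical hidden size
    target = xavier_bound(hidden, hidden)                        # desired [hidden, hidden] bound
    kv  = xavier_bound(hidden, 2 * hidden, gain=math.sqrt(1.5))  # [2*hidden, hidden] weight, as above
    qkv = xavier_bound(hidden, 3 * hidden, gain=math.sqrt(2.0))  # [3*hidden, hidden] weight, see the next file
    assert abs(kv - target) < 1e-12 and abs(qkv - target) < 1e-12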
apex/contrib/multihead_attn/self_multihead_attn.py

+import math
 import torch
 from torch import nn
 from torch.nn import Parameter

@@ -98,7 +100,11 @@ class SelfMultiheadAttn(nn.Module):
             nn.init.xavier_uniform_(self.k_weight)
             nn.init.xavier_uniform_(self.v_weight)
         else:
-            nn.init.xavier_uniform_(self.in_proj_weight)
+            # in_proj_weight has shape [3 * hidden, hidden] but it should be
+            # initialized like a [hidden, hidden] matrix.
+            # sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)
+            # therefore xavier_uniform gain should be set to sqrt(2).
+            nn.init.xavier_uniform_(self.in_proj_weight, gain=math.sqrt(2))
         nn.init.xavier_uniform_(self.out_proj_weight)
         if self.bias:
             if self.separate_qkv_params:
apex/contrib/sparsity/asp.py

@@ -6,7 +6,7 @@ torchvision_imported=True
 try:
     import torchvision
 except ImportError:
-    print("[ASP][Warning] torchvision cannot be imported, may infuence functionality of MaskRCNN/KeypointRCNN network from torchvision.")
+    print("[ASP][Warning] torchvision cannot be imported.")
     torchvision_imported=False

 def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names):
@@ -78,7 +78,7 @@ class ASP:
         # function to extract variables that will be sparsified.
         # idea is that you will add one of these functions for each module type that can be sparsified.
         if torchvision_imported:
-            print("[ASP] torchvision is imported, can work smoothly with the MaskRCNN/KeypointRCNN from torchvision.")
+            print("[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.")
             sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torchvision.ops.misc.Conv2d: ['weight']}
         else:
             sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight']}
@@ -102,7 +102,7 @@ class ASP:
                     print("[ASP] Sparsifying %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
                     mask = torch.ones_like(p).bool()
-                    buffname = name.split(".")[-1] # buffer names cannot contain "."
+                    buffname = p_name.split(".")[-1] # buffer names cannot contain "."
                     module.register_buffer('__%s_mma_mask' % buffname, mask)
                     if allow_recompute_mask:
                         pruned = torch.zeros_like(p).cpu()
apex/contrib/sparsity/sparse_masklib.py

@@ -5,7 +5,7 @@ import collections
 from itertools import permutations


-""" compute density (helper fn to compute % NNZs in a tensor)"""
+""" compute density (helper fn to compute % NNZs in a tensor) """
 def fill(x):
     return float(x.nonzero().size(0))/torch.numel(x)
@@ -20,7 +20,7 @@ def reshape_1d(matrix, m):
     else:
         return matrix.view(-1,m), matrix.shape

-""" return all possible m:n patterns in a 1d vector. """
+""" return all possible m:n patterns in a 1d vector """
 valid_m4n2_1d_patterns = None
 def compute_valid_1d_patterns(m,n):
     # Early exit if patterns was already created.
@@ -49,8 +49,21 @@ def mn_1d_best(matrix, m, n):
 def m4n2_1d(mat, density):
     return mn_1d_best(mat, 4, 2)

-""" Comment: Following 2d masking related code (for training) can be removed or marked experimental (78 LOC) """
-""" m:n 2d structured greedy """
+"""
+  Below 2d-masking related code is targeted more for training (from scratch).
+  2d-pruning of a weight tensor is done to accelerate DGRAD step during backprop
+  phase of training algorithm. Acceleration comes from using SpMMA instructions in
+  Tensor Cores of NVIDIA Ampere GPU Architecture
+  (note: this code does not do the acceleration, GPU kernels are required for this).
+  1d pruning of weight tensor helps speed up FPROP step by pruning in 2:4 pattern
+  along the horizontal (logical) direction.
+  During DGRAD step, weight tensor is transposed. 2d pruning functions below, mask
+  weight tensor such that their transposed versions are also 2:4 sparse along the
+  horizontal (logical) direction. Thus, with 2d pruning, weight tensors are
+  2:4 sparse along row and column directions.
+"""
+
+""" m:n 2d structured pruning: greedy method to select mask """
 def mn_2d_greedy(matrix, m, n):
     # Convert to numpy
     mat = matrix.cpu().detach().numpy()
@@ -105,7 +118,7 @@ def compute_valid_2d_patterns(m,n):
     if m==4 and n==2: valid_m4n2_2d_patterns = valid_patterns
     return valid_patterns

-""" m:n 2d structured best """
+""" m:n 2d structured pruning: exhaustive method to select best mask """
 def mn_2d_best(matrix, m, n):
     # Find all possible patterns.
     patterns = compute_valid_2d_patterns(m,n).cuda()
@@ -127,6 +140,7 @@ def mn_2d_best(matrix, m, n):
 def m4n2_2d_best(mat, density):
     return mn_2d_best(mat, 4, 2)

+
 """ returns a sparse mask """
 def create_mask(tensor, pattern="m4n2_1d", density=0.5):
     # Reshape tensor and mask.
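To make the 2:4 pattern concrete, here is a minimal magnitude-based sketch of 1d m:n pruning. It is independent of the greedy/exhaustive pattern search in sparse_masklib, and the function name and shapes are illustrative only:

    import torch

    def m4n2_1d_magnitude_mask(weight: torch.Tensor) -> torch.Tensor:
        """Keep the 2 largest-magnitude values in every group of 4 along the last dim."""
        m, n = 4, 2
        flat = weight.abs().reshape(-1, m)             # group elements in chunks of m
        topk = flat.topk(n, dim=1).indices             # indices of the n survivors per group
        mask = torch.zeros_like(flat, dtype=torch.bool)
        mask.scatter_(1, topk, True)
        return mask.reshape(weight.shape)

    w = torch.randn(8, 16)                             # last dim must be divisible by 4
    mask = m4n2_1d_magnitude_mask(w)
    assert mask.reshape(-1, 4).sum(dim=1).eq(2).all()  # every group of 4 keeps exactly 2
    sparse_w = w * mask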
apex/parallel/optimized_sync_batchnorm_kernel.py

@@ -28,16 +28,24 @@ class SyncBatchnormFunction(Function):
         if torch.distributed.is_initialized():
             if not process_group:
                 process_group = torch.distributed.group.WORLD
+            device = mean.device
             world_size = torch.distributed.get_world_size(process_group)
-            mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
-            var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=var_biased.device)
+            mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=device)
+            var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=device)
+            count_all = torch.cuda.IntTensor(world_size, device=device)
             mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
             var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
+            count_l = [count_all.narrow(0, i, 1) for i in range(world_size)]
             torch.distributed.all_gather(mean_l, mean, process_group)
             torch.distributed.all_gather(var_l, var_biased, process_group)
-            mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count, eps)
-            # TODO(Jie): should do fp32 math instead!
+            torch.distributed.all_gather(count_l, torch.cuda.IntTensor([count], device=device), process_group)
+            mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count_all, eps)
         else:
+            device = mean.device
+            count_all = torch.cuda.IntTensor([count], device=device)
             inv_std = 1.0 / torch.sqrt(var_biased + eps)
             var = var_biased * (count) / (count-1)
@@ -52,7 +60,7 @@ class SyncBatchnormFunction(Function):
             mean = running_mean.data
             inv_std = 1.0 / torch.sqrt(running_variance.data + eps)

-        ctx.save_for_backward(input, weight, mean, inv_std, z, bias)
+        ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all)
         ctx.process_group = process_group
         ctx.channel_last = channel_last
         ctx.world_size = world_size
@@ -71,7 +79,7 @@ class SyncBatchnormFunction(Function):
         # mini batch mean & var are calculated by forward path.
         # mu = 1./N*np.sum(h, axis = 0)
         # var = 1./N*np.sum((h-mu)**2, axis = 0)
-        saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
+        saved_input, weight, mean, inv_std, z, bias, count = ctx.saved_tensors
         process_group = ctx.process_group
         channel_last = ctx.channel_last
         world_size = ctx.world_size
@@ -83,26 +91,24 @@ class SyncBatchnormFunction(Function):
         if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
             grad_z = grad_output.clone()

-        # TODO(jie): why do I have to clone here? life time of grad_output?
+        # TODO: update kernel to not pre_divide by item_num
         if channel_last:
-            mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
+            sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
         else:
-            mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)
+            sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)

         # calculate grad_input
         if ctx.needs_input_grad[0]:

             if torch.distributed.is_initialized():
                 torch.distributed.all_reduce(
-                    mean_dy, ReduceOp.SUM, process_group)
-                mean_dy = mean_dy / world_size
+                    sum_dy, ReduceOp.SUM, process_group)
                 torch.distributed.all_reduce(
-                    mean_dy_xmu, ReduceOp.SUM, process_group)
-                mean_dy_xmu = mean_dy_xmu / world_size
+                    sum_dy_xmu, ReduceOp.SUM, process_group)
             if channel_last:
-                grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
+                grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)
             else:
-                grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
+                grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)

         if weight is None or not ctx.needs_input_grad[2]:
             grad_weight = None
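The switch from all-reducing per-rank means (then dividing by world_size) to all-reducing raw sums plus a per-rank element count matters when ranks hold different batch sizes: averaging per-rank means weights every rank equally, while summing and dividing by the total count weights every sample equally. A minimal sketch of the two reductions with illustrative values only:

    import torch

    # per-rank per-channel sums of dy, and per-rank element counts (unequal batches)
    sum_dy_per_rank = [torch.tensor([12.0]), torch.tensor([4.0])]
    count_per_rank = [16, 8]

    # old style: all-reduce the per-rank means, then divide by world_size
    mean_dy_old = sum(s / c for s, c in zip(sum_dy_per_rank, count_per_rank)) / len(count_per_rank)

    # new style: all-reduce the raw sums and divide by the global element count
    mean_dy_new = sum(sum_dy_per_rank) / sum(count_per_rank)

    print(mean_dy_old.item(), mean_dy_new.item())  # 0.625 vs 0.666..., they differ once counts differ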
csrc/syncbn.cpp

@@ -12,7 +12,7 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);
 // implemented using welford
-std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased_feature_nodes, int numel, const float eps);
+std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased_feature_nodes, const at::Tensor numel, const float eps);

 // elementwise BN operation, returns output
@@ -24,7 +24,7 @@ at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
                                   const at::optional<at::Tensor> weight,
                                   const at::optional<at::Tensor> shift);

-// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
+// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
 // grad_output/input should have identical data type;
 // mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
 // implemented using kahan summation
@@ -36,14 +36,15 @@ std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
 // elementwise backward BN operation, returns grad_input
 // grad_output/input/weight precision could be fp16/fp32;
-// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
+// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
 at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
                                    const at::Tensor input,
                                    const at::Tensor mean,
                                    const at::Tensor inv_std,
                                    const at::optional<at::Tensor> weight,
-                                   const at::Tensor mean_dy,
-                                   const at::Tensor mean_dy_xmu);
+                                   const at::Tensor sum_dy,
+                                   const at::Tensor sum_dy_xmu,
+                                   const at::Tensor count);

 // returns {mean, biased_var}
 // implemented using welford
@@ -62,7 +63,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
                                          const at::optional<at::Tensor> shift,
                                          const bool fuse_relu);

-// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
+// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
 // grad_output/input should have identical data type;
 // mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
 // expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
@@ -74,15 +75,16 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
 // elementwise backward BN operation, returns grad_input
 // grad_output/input/weight precision could be fp16/fp32;
-// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
+// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
 // expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
 at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
                                           const at::Tensor input,
                                           const at::Tensor mean,
                                           const at::Tensor inv_std,
                                           const at::optional<at::Tensor> weight,
-                                          const at::Tensor mean_dy,
-                                          const at::Tensor mean_dy_xmu);
+                                          const at::Tensor sum_dy,
+                                          const at::Tensor sum_dy_xmu,
+                                          const at::Tensor count);

 at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
                                      const at::Tensor input,
csrc/welford.cu

@@ -327,15 +327,15 @@ __global__ void reduce_bn_kernel(
       const scalar_t* __restrict__ grad_output,
       const accscalar_t* __restrict__ mean,
       const accscalar_t* __restrict__ inv_std,
-      accscalar_t* __restrict__ mean_dy,
-      accscalar_t* __restrict__ mean_dy_xmu,
+      accscalar_t* __restrict__ sum_dy_o,
+      accscalar_t* __restrict__ sum_dy_xmu_o,
       layerscalar_t* __restrict__ grad_weight,
       layerscalar_t* __restrict__ grad_bias,
       const int bs,
       const int fs,
       const int ss) {
   static __shared__ int s_mem[64];
-  int total_item_num = bs * ss;
+  // int total_item_num = bs * ss;
   int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
@@ -377,8 +377,10 @@ __global__ void reduce_bn_kernel(
     if (grad_weight != NULL) {
       grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
     }
-    mean_dy[blockIdx.x] = sum_dy / total_item_num;
-    mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
+    //mean_dy[blockIdx.x] = sum_dy / total_item_num;
+    //mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
+    sum_dy_o[blockIdx.x] = sum_dy;
+    sum_dy_xmu_o[blockIdx.x] = sum_dy_xmu;
   }
 }
@@ -390,16 +392,24 @@ __global__ void batchnorm_backward_kernel(
       const accscalar_t* __restrict__ mean,
       const accscalar_t* __restrict__ inv_std,
       const layerscalar_t* __restrict__ weight,
-      const accscalar_t* __restrict__ mean_dy,
-      const accscalar_t* __restrict__ mean_dy_xmu,
+      const accscalar_t* __restrict__ sum_dy,
+      const accscalar_t* __restrict__ sum_dy_xmu,
+      const int* __restrict__ numel,
       scalar_t* __restrict__ grad_input,
+      const int64_t world_size,
       const int ss,
       const int bs) {
+  int64_t div = 0;
+  for (int i = 0; i < world_size; i++) {
+    div += numel[i];
+  }
   auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
-  auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
+  //auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
+  auto m_dy_c = static_cast<accscalar_t>(sum_dy[blockIdx.x]) / div;
   auto factor_1_c = inv_std[blockIdx.x];
   auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x])) * factor_1_c;
-  factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
+  //factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
+  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[blockIdx.x] / div;

   for (int batch_offset = blockIdx.y*blockDim.y + threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
     int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
@@ -559,13 +569,13 @@ template <typename scalar_t>
 __global__ void welford_kernel_parallel(
       const scalar_t* __restrict__ mean,
       const scalar_t* __restrict__ var_biased,
+      const int* __restrict__ numel,
       scalar_t* __restrict__ out_mean,
       scalar_t* __restrict__ out_var,
       scalar_t* __restrict__ inv_std,
       const int world_size,
       const int feature_size,
-      const float eps,
-      const int numel) {
+      const float eps) {

   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < feature_size; i += gridDim.x * blockDim.x) {
     // load data;
@@ -574,7 +584,7 @@ __global__ void welford_kernel_parallel(
     scalar_t m_2_n = 0;
     int count = 0;
     for (int j = 0; j < world_size; j++) {
-      welford_merge_element(count, x_mean, m_2_n, numel, mean[address], var_biased[address]*numel);
+      welford_merge_element(count, x_mean, m_2_n, numel[j], mean[address], var_biased[address]*numel[j]);
       address += feature_size;
     }
     out_mean[i] = x_mean;
@@ -694,8 +704,8 @@ __global__ void reduce_bn_c_last_kernel(
       const scalar_t* __restrict__ grad_output,
       const accscalar_t* __restrict__ mean,
       const accscalar_t* __restrict__ inv_std,
-      accscalar_t* __restrict__ mean_dy,
-      accscalar_t* __restrict__ mean_dy_xmu,
+      accscalar_t* __restrict__ sum_dy_o,
+      accscalar_t* __restrict__ sum_dy_xmu_o,
       layerscalar_t* __restrict__ grad_weight,
       layerscalar_t* __restrict__ grad_bias,
       volatile accscalar_t* staging_data,
@@ -814,8 +824,10 @@ __global__ void reduce_bn_c_last_kernel(
       if (grad_weight != NULL) {
         grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
       }
-      mean_dy[c_offset] = sum_dy_th / reduction_size;
-      mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
+      //mean_dy[c_offset] = sum_dy_th / reduction_size;
+      //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
+      sum_dy_o[c_offset] = sum_dy_th;
+      sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
     }
   }
 }
 else {
@@ -826,8 +838,10 @@ __global__ void reduce_bn_c_last_kernel(
       if (grad_weight != NULL) {
         grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
       }
-      mean_dy[c_offset] = sum_dy_th / reduction_size;
-      mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
+      //mean_dy[c_offset] = sum_dy_th / reduction_size;
+      //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
+      sum_dy_o[c_offset] = sum_dy_th;
+      sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
     }
   }
 }
@@ -844,11 +858,17 @@ __global__ void batchnorm_backward_c_last_kernel(
       const accscalar_t* __restrict__ mean,
       const accscalar_t* __restrict__ inv_std,
       const layerscalar_t* __restrict__ weight,
-      const accscalar_t* __restrict__ mean_dy,
-      const accscalar_t* __restrict__ mean_dy_xmu,
+      const accscalar_t* __restrict__ sum_dy,
+      const accscalar_t* __restrict__ sum_dy_xmu,
+      const int* __restrict__ numel,
       scalar_t* __restrict__ grad_input,
+      const int64_t world_size,
       const int reduction_size,
       const int stride) {
+  int64_t div = 0;
+  for (int i = 0; i < world_size; i++) {
+    div += numel[i];
+  }
   // tensor dimension (m,c)
   // loop along m dimension
   int inner_loop_stride = blockDim.y * gridDim.y;
@@ -858,10 +878,10 @@ __global__ void batchnorm_backward_c_last_kernel(
   int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

   auto m_c = mean[c_offset];
-  auto m_dy_c = mean_dy[c_offset];
+  auto m_dy_c = sum_dy[c_offset] / div;
   auto factor_1_c = inv_std[c_offset];
   auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
-  factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[c_offset];
+  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] / div;

   int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
   int address_base = m_offset * stride + c_offset;
@@ -986,8 +1006,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
   auto scalar_type = promote_scalartype(input);

-  at::Tensor mean_dy = at::empty({feature_size}, mean.options());
-  at::Tensor mean_dy_xmu = at::empty({feature_size}, mean.options());
+  at::Tensor sum_dy = at::empty({feature_size}, mean.options());
+  at::Tensor sum_dy_xmu = at::empty({feature_size}, mean.options());

   at::Tensor grad_weight;
   at::Tensor grad_bias;
@@ -1018,8 +1038,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
           grad_output.DATA_PTR<scalar_t_0>(),
           mean.DATA_PTR<accscalar_t>(),
           inv_std.DATA_PTR<accscalar_t>(),
-          mean_dy.DATA_PTR<accscalar_t>(),
-          mean_dy_xmu.DATA_PTR<accscalar_t>(),
+          sum_dy.DATA_PTR<accscalar_t>(),
+          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
          weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
           batch_size,
@@ -1039,8 +1059,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
           grad_output.DATA_PTR<scalar_t_0>(),
           mean.DATA_PTR<accscalar_t>(),
           inv_std.DATA_PTR<accscalar_t>(),
-          mean_dy.DATA_PTR<accscalar_t>(),
-          mean_dy_xmu.DATA_PTR<accscalar_t>(),
+          sum_dy.DATA_PTR<accscalar_t>(),
+          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
          weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
           batch_size,
@@ -1049,7 +1069,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
       );
     }
   }
-  return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
+  return {sum_dy, sum_dy_xmu, grad_weight, grad_bias};
 }

 at::Tensor batchnorm_backward_CUDA(
@@ -1058,8 +1078,9 @@ at::Tensor batchnorm_backward_CUDA(
     const at::Tensor mean,
     const at::Tensor inv_std,
     const at::optional<at::Tensor> weight,
-    const at::Tensor mean_dy,
-    const at::Tensor mean_dy_xmu) {
+    const at::Tensor sum_dy,
+    const at::Tensor sum_dy_xmu,
+    const at::Tensor count) {
   const auto batch_size = input.size(0);
   const auto feature_size = input.size(1);
@@ -1088,9 +1109,11 @@ at::Tensor batchnorm_backward_CUDA(
           mean.DATA_PTR<accscalar_t>(),
           inv_std.DATA_PTR<accscalar_t>(),
           weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
-          mean_dy.DATA_PTR<accscalar_t>(),
-          mean_dy_xmu.DATA_PTR<accscalar_t>(),
+          sum_dy.DATA_PTR<accscalar_t>(),
+          sum_dy_xmu.DATA_PTR<accscalar_t>(),
+          count.DATA_PTR<int>(),
           grad_input.DATA_PTR<scalar_t_0>(),
+          count.numel(),
           space_size,
           batch_size
       );
@@ -1108,9 +1131,11 @@ at::Tensor batchnorm_backward_CUDA(
           mean.DATA_PTR<accscalar_t>(),
           inv_std.DATA_PTR<accscalar_t>(),
           weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
-          mean_dy.DATA_PTR<accscalar_t>(),
-          mean_dy_xmu.DATA_PTR<accscalar_t>(),
+          sum_dy.DATA_PTR<accscalar_t>(),
+          sum_dy_xmu.DATA_PTR<accscalar_t>(),
+          count.DATA_PTR<int>(),
           grad_input.DATA_PTR<scalar_t_0>(),
+          count.numel(),
           space_size,
           batch_size
       );
@@ -1121,7 +1146,7 @@ at::Tensor batchnorm_backward_CUDA(
-std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased, int numel, const float eps) {
+std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased, const at::Tensor numel, const float eps) {
   const auto world_size = mean_feature_nodes.size(0);
   const auto feature_size = mean_feature_nodes.size(1);
@@ -1142,13 +1167,13 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
     welford_kernel_parallel<scalar_t_0><<<grid, block, 0, stream>>>(
         mean_feature_nodes.DATA_PTR<scalar_t_0>(),
         var_biased.DATA_PTR<scalar_t_0>(),
+        numel.DATA_PTR<int>(),
         out_mean.DATA_PTR<scalar_t_0>(),
         out_var.DATA_PTR<scalar_t_0>(),
         inv_std.DATA_PTR<scalar_t_0>(),
         world_size,
         feature_size,
-        eps,
-        numel);
+        eps);
   );
 }
@@ -1270,8 +1295,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
   const auto stride = input.size(input.ndimension()-1);
   const auto reduction_size = input.numel() / stride;

-  at::Tensor mean_dy = at::empty({stride}, mean.options());
-  at::Tensor mean_dy_xmu = at::empty({stride}, mean.options());
+  at::Tensor sumn_dy = at::empty({stride}, mean.options());
+  at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());

   at::Tensor grad_weight;
   at::Tensor grad_bias;
@@ -1310,8 +1335,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
           grad_output.DATA_PTR<scalar_t_0>(),
           mean.DATA_PTR<accscalar_t>(),
           inv_std.DATA_PTR<accscalar_t>(),
-          mean_dy.DATA_PTR<accscalar_t>(),
-          mean_dy_xmu.DATA_PTR<accscalar_t>(),
+          sumn_dy.DATA_PTR<accscalar_t>(),
+          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
          weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
           staging_data_ptr,
@@ -1335,8 +1360,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
           grad_output.DATA_PTR<scalar_t_0>(),
           mean.DATA_PTR<accscalar_t>(),
           inv_std.DATA_PTR<accscalar_t>(),
-          mean_dy.DATA_PTR<accscalar_t>(),
-          mean_dy_xmu.DATA_PTR<accscalar_t>(),
+          sumn_dy.DATA_PTR<accscalar_t>(),
+          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
          weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
           staging_data_ptr,
@@ -1346,7 +1371,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
       );
     }
   }
-  return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
+  return {sumn_dy, sum_dy_xmu, grad_weight, grad_bias};
 }

 at::Tensor batchnorm_backward_c_last_CUDA(
@@ -1355,8 +1380,9 @@ at::Tensor batchnorm_backward_c_last_CUDA(
     const at::Tensor mean,
     const at::Tensor inv_std,
     const at::optional<at::Tensor> weight,
-    const at::Tensor mean_dy,
-    const at::Tensor mean_dy_xmu) {
+    const at::Tensor sum_dy,
+    const at::Tensor sum_dy_xmu,
+    const at::Tensor count) {
   const auto stride = input.size(input.ndimension()-1);
   const auto reduction_size = input.numel() / stride;
@@ -1380,9 +1406,11 @@ at::Tensor batchnorm_backward_c_last_CUDA(
           mean.DATA_PTR<accscalar_t>(),
           inv_std.DATA_PTR<accscalar_t>(),
           weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
-          mean_dy.DATA_PTR<accscalar_t>(),
-          mean_dy_xmu.DATA_PTR<accscalar_t>(),
+          sum_dy.DATA_PTR<accscalar_t>(),
+          sum_dy_xmu.DATA_PTR<accscalar_t>(),
+          count.DATA_PTR<int>(),
           grad_input.DATA_PTR<scalar_t_0>(),
+          count.numel(),
           reduction_size,
           stride);
   );
@@ -1401,9 +1429,11 @@ at::Tensor batchnorm_backward_c_last_CUDA(
           mean.DATA_PTR<accscalar_t>(),
           inv_std.DATA_PTR<accscalar_t>(),
           weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
-          mean_dy.DATA_PTR<accscalar_t>(),
-          mean_dy_xmu.DATA_PTR<accscalar_t>(),
+          sum_dy.DATA_PTR<accscalar_t>(),
+          sum_dy_xmu.DATA_PTR<accscalar_t>(),
+          count.DATA_PTR<int>(),
           grad_input.DATA_PTR<scalar_t_0>(),
+          count.numel(),
           reduction_size,
           stride);
   );
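The per-rank merge that welford_kernel_parallel now performs with a per-rank numel[j] can be mirrored in a few lines of Python. This is a sketch of the standard parallel-variance (Chan) merge, not the CUDA code itself; shapes and counts are illustrative:

    import torch

    def welford_merge(count_a, mean_a, m2_a, count_b, mean_b, m2_b):
        # merge two (count, mean, M2) accumulators; M2 is count * biased variance
        count = count_a + count_b
        delta = mean_b - mean_a
        mean = mean_a + delta * count_b / count
        m2 = m2_a + m2_b + delta * delta * count_a * count_b / count
        return count, mean, m2

    # two ranks with different element counts, e.g. batches of 16 and 8 over an 8x8 map
    x_a, x_b = torch.randn(16 * 64), torch.randn(8 * 64)
    stats = (0, torch.tensor(0.0), torch.tensor(0.0))
    for x in (x_a, x_b):
        stats = welford_merge(*stats, x.numel(), x.mean(), x.var(unbiased=False) * x.numel())
    count, mean, m2 = stats

    full = torch.cat([x_a, x_b])
    assert torch.allclose(mean, full.mean(), atol=1e-5)
    assert torch.allclose(m2 / count, full.var(unbiased=False), atol=1e-5)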
tests/L0/run_optimizers/test_lamb.py (new file, 0 → 100644)

import unittest
import os

import torch
from torch.optim import Optimizer
import apex
from apex.multi_tensor_apply import multi_tensor_applier


class RefLAMB(Optimizer):
    r"""Implements Lamb algorithm.

    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)

    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(RefLAMB, self).__init__(params, defaults)
        if multi_tensor_applier.available:
            import amp_C
            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
            # Skip buffer
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            self.multi_tensor_lamb = amp_C.multi_tensor_lamb
        else:
            raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # create separate grad lists for fp32 and fp16 params
        g_all_32, g_all_16 = [], []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if p.dtype == torch.float32:
                    g_all_32.append(p.grad.data)
                elif p.dtype == torch.float16:
                    g_all_16.append(p.grad.data)
                else:
                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')

        g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda')
        # compute grad norm for two lists
        if len(g_all_32) > 0:
            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_32], False)[0]
        if len(g_all_16) > 0:
            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_16], False)[0]

        # blend two grad norms to get global grad norm
        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
                                                self._dummy_overflow_buf,
                                                [[g_norm_32, g_norm_16]], False)[0]
        max_grad_norm = 1.0
        clipped_ratio = max_grad_norm / max(global_grad_norm, max_grad_norm)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.grad.data *= clipped_ratio
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['v'] = torch.zeros_like(p.data)

                m_t, v_t = state['m'], state['v']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # m_t = beta1 * m + (1 - beta1) * g_t
                m_t.mul_(beta1).add_(grad, alpha=1-beta1)
                # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
                v_t.mul_(beta2).addcmul_(grad, grad, value=1-beta2)

                # Debiasing
                m_t_hat = m_t / (1.0 - beta1 ** state['step'])
                v_t_hat = v_t / (1.0 - beta2 ** state['step'])

                update = m_t_hat / v_t_hat.sqrt().add(group['eps'])

                if group['weight_decay'] != 0:
                    update.add_(p.data, alpha=group['weight_decay'])

                trust_ratio = 1.0
                w_norm = p.data.pow(2).sum().sqrt()
                g_norm = update.pow(2).sum().sqrt()
                if w_norm > 0 and g_norm > 0:
                    trust_ratio = w_norm / g_norm

                state['w_norm'] = w_norm
                state['g_norm'] = g_norm
                state['trust_ratio'] = trust_ratio

                step_size = group['lr']

                p.data.add_(update, alpha=-step_size*trust_ratio)

        return loss


class TestFusedLAMB(unittest.TestCase):
    def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
        self.max_abs_diff = max_abs_diff
        self.max_rel_diff = max_rel_diff
        self.iters = iters
        torch.cuda.manual_seed(9876)

    def tearDown(self):
        pass

    def gen_param_optim(self, tensors, lamb_option):
        ref_param = []
        tst_param = []
        for tensor in tensors:
            ref_param.append(torch.nn.Parameter(tensor.clone()))
            tst_param.append(torch.nn.Parameter(tensor.clone()))

        ref_optim = RefLAMB(ref_param, **lamb_option)
        tst_optim = apex.optimizers.FusedLAMB(tst_param, use_nvlamb=True, **lamb_option)

        return (ref_param, tst_param, ref_optim, tst_optim)

    def gen_grad(self, ref_param, tst_param):
        for p_ref, p_tst in zip(ref_param, tst_param):
            p_ref.grad = torch.rand_like(p_ref)
            p_tst.grad = p_ref.grad

    def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
        half_grads = []
        for p_ref, _ in zip(ref_param, tst_param):
            half_grads.append(torch.rand_like(p_ref).half())
            p_ref.grad = half_grads[-1].float() / scale
        return half_grads

    def get_max_diff(self, ref_param, tst_param):
        max_abs_diff = max_rel_diff = 0
        for p_ref, p_tst in zip(ref_param, tst_param):
            max_abs_diff_p = (p_ref - p_tst).abs().max().item()
            max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()

            if max_abs_diff_p > max_abs_diff:
                max_abs_diff = max_abs_diff_p
            if max_rel_diff_p > max_rel_diff:
                max_rel_diff = max_rel_diff_p

        return max_abs_diff, max_rel_diff

    def gen_single_type_test(self, param_type=torch.float):
        nelem = 278011
        tensor = torch.rand(nelem, dtype=param_type, device='cuda')
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': wd}
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim([tensor], lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

                self.assertLessEqual(max_abs_diff, self.max_abs_diff)
                self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    @unittest.skip("PyTorch optimizer is not numerically correct for fp16")
    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    def test_multi_params(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': wd}
            tensors = []
            for size in sizes:
                tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim(tensors, lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
                self.assertLessEqual(max_abs_diff, self.max_abs_diff)
                self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_lamb_option(self):
        nelem = 1
        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr': 0.01, 'betas': (0.6, 0.9), 'eps': 3e-06, 'weight_decay': wd}
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim([tensor], lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

                self.assertLessEqual(max_abs_diff, self.max_abs_diff)
                self.assertLessEqual(max_rel_diff, self.max_rel_diff)


if __name__ == '__main__':
    script_path = os.path.dirname(os.path.realpath(__file__))
    unittest.main()
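As a quick illustration of how the reference optimizer above would be exercised outside the unit test (a minimal sketch, assuming a CUDA device, an Apex build with the amp_C extension, and that test_lamb.py is importable from the working directory):

    import torch
    from test_lamb import RefLAMB  # assumes tests/L0/run_optimizers is on the path

    model = torch.nn.Linear(16, 4).cuda()
    opt = RefLAMB(model.parameters(), lr=1e-3, weight_decay=0.01)

    x = torch.randn(8, 16, device='cuda')
    loss = model(x).pow(2).mean()
    loss.backward()
    opt.step()        # one LAMB update: Adam-style moments scaled by the per-tensor trust ratio
    opt.zero_grad()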
tests/distributed/synced_batchnorm/single_gpu_unit_test.py

@@ -35,6 +35,7 @@ inp = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype
 grad = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype)
 weight = (np.random.randn(feature_size)).astype(dtype)
 bias = (np.random.randn(feature_size)).astype(dtype)
+count = torch.cuda.IntTensor([batch_size*space_size**2])

 type_tensor = torch.cuda.FloatTensor
 ref_tensor = torch.cuda.DoubleTensor
@@ -110,17 +111,19 @@ grad_output2_r = ref_tensor(grad)
 grad_bias_r = grad_output_r.sum(1)
 grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)

+sum_dy_r = grad_output_r.sum(1)
 mean_dy_r = grad_output_r.mean(1)
+sum_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
 mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1)

 grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)

-mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
-grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
+sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
+grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, sum_dy, sum_dy_xmu, count)

 sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
 sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
-sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
-sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result
+sbn_result = compare("comparing sum_dy grad: ", sum_dy, sum_dy_r, error) and sbn_result
+sbn_result = compare("comparing sum_dy_xmu grad: ", sum_dy_xmu, sum_dy_xmu_r, error) and sbn_result
 sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
 compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
 sbn_result = compare("comparing sbn input grad: ", inp_sbn.grad, grad_input_r, error) and sbn_result
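For reference, the batchnorm backward identity that the test's grad_input_r encodes can also be written in terms of the summed quantities and the element count, which is what the updated kernel now consumes. A minimal sketch checked against autograd on a tiny tensor (shapes and names are illustrative, not the test's):

    import torch

    N, C, eps = 32, 4, 1e-5
    x = torch.randn(N, C, dtype=torch.double, requires_grad=True)
    dy = torch.randn(N, C, dtype=torch.double)

    # autograd reference: gradient of the normalization (x - mu) / sqrt(var + eps)
    mu, var = x.mean(0), x.var(0, unbiased=False)
    y = (x - mu) / torch.sqrt(var + eps)
    y.backward(dy)

    # same gradient written with sum_dy, sum_dy_xmu, and the element count N
    xd = x.detach()
    mu_d, var_d = xd.mean(0), xd.var(0, unbiased=False)
    inv_std = 1.0 / torch.sqrt(var_d + eps)
    sum_dy = dy.sum(0)
    sum_dy_xmu = (dy * (xd - mu_d)).sum(0)
    grad_input = (dy - sum_dy / N - (xd - mu_d) * inv_std**2 * sum_dy_xmu / N) * inv_std
    assert torch.allclose(grad_input, x.grad)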
tests/distributed/synced_batchnorm/two_gpu_test_different_batch_size.py (new file, 0 → 100755)

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from apex.parallel import SyncBatchNorm as ApexSyncBatchNorm

import argparse
import os

import numpy as np

var_batch = 16

def compare(desc, inp1, inp2, error=1e-5):
    a = inp1.clone().detach().cpu().numpy()
    b = inp2.clone().detach().cpu().numpy()
    close = np.allclose(a, b, error, error)
    if not close:
        print(desc, close)
        z = a - b
        index = (np.abs(z) >= error + error * np.abs(b)).nonzero()
        print("dif    : ", z[index])
        print("inp1   : ", a[index])
        print("inp2   : ", b[index])
    return close

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--apex', action='store_true')
args = parser.parse_args()

torch.manual_seed(2809)
# Setup DDP
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda:{}'.format(args.local_rank))

torch.distributed.init_process_group(
    'nccl',
    init_method='env://',
    rank=args.local_rank,
)

# Setup model
if args.apex:
    model = nn.Sequential(
        nn.Conv2d(3, 6, 3, 1, 1),
        ApexSyncBatchNorm(6)
    )
else:
    model = nn.Sequential(
        nn.Conv2d(3, 6, 3, 1, 1),
        nn.SyncBatchNorm(6)
    )

# Setup reference model
model_reference = nn.Sequential(
    nn.Conv2d(3, 6, 3, 1, 1),
    nn.BatchNorm2d(6)
)

with torch.no_grad():
    model_reference[0].weight.copy_(model[0].weight)
    model_reference[0].bias.copy_(model[0].bias)
model_reference.to(device)

model = model.to(device)
model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

global_batch_size = var_batch + 8
# Create random data
if args.local_rank == 0:
    data = torch.randn(var_batch, 3, 8, 8, device=device, dtype=torch.float) * 50.0
    grad = torch.randint(0, 10, (var_batch, 6, 8, 8), device=device, dtype=torch.float) / 10.0
else:
    data = torch.randn(8, 3, 8, 8, device=device)
    grad = torch.randint(0, 10, (8, 6, 8, 8), device=device, dtype=torch.float) / 10.0

data.requires_grad_()
data.retain_grad = True

weighted_gradient = True

# DDP forward/backward
output = model(data)

if weighted_gradient:
    output.backward(grad * 2 / global_batch_size)
else:
    output.backward(grad / output.size(0))

d_list = [torch.randn(8, 3, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]
y_list = [torch.randn(8, 6, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]
dgrad_list = [torch.randn(8, 3, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]
grad_list = [torch.randn(8, 6, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]

if args.local_rank == 0:
    # placeholder, these random data will later be discarded.
    torch.distributed.all_gather(d_list, torch.randn(8, 3, 8, 8, device=device))
    torch.distributed.all_gather(y_list, torch.randn(8, 6, 8, 8, device=device))
    torch.distributed.all_gather(dgrad_list, torch.randn(8, 3, 8, 8, device=device))
    torch.distributed.all_gather(grad_list, torch.randn(8, 6, 8, 8, device=device))
else:
    torch.distributed.all_gather(d_list, data)
    torch.distributed.all_gather(y_list, output)
    torch.distributed.all_gather(dgrad_list, data.grad)
    torch.distributed.all_gather(grad_list, grad)

torch.distributed.barrier()

if args.local_rank == 0:
    ref_tensor = d_list[1:]
    ref_tensor.insert(0, data)
    assert (ref_tensor[0].equal(data))
    ref_tensor = torch.cat(ref_tensor, 0)
    ref_tensor = ref_tensor.detach()
    ref_tensor.requires_grad_()
    ref_tensor.retain_grad()

    # Reference forward/backward
    output_reference = model_reference(ref_tensor)
    grad_tensor = grad_list[1:]
    grad_tensor.insert(0, grad)
    assert (grad_tensor[0].equal(grad))
    grad_tensor = torch.cat(grad_tensor, 0)
    if weighted_gradient:
        output_reference.backward(grad_tensor / output_reference.size(0))
    else:
        output_reference.backward(grad_tensor / output_reference.size(0))

    dgrad_tensor = dgrad_list[1:]
    dgrad_tensor.insert(0, data.grad)
    dgrad_tensor = torch.cat(dgrad_tensor, 0)

    # check output
    output_tensor = y_list[1:]
    output_tensor.insert(0, output)
    output_tensor = torch.cat(output_tensor, 0)
    passed = True
    passed = passed and compare("check output", output_tensor, output_reference)
    # check stats
    passed = passed and compare("check running mean failed", model_reference[1].running_mean, model.module[1].running_mean)
    passed = passed and compare("check running var failed", model_reference[1].running_var, model.module[1].running_var)
    passed = passed and compare("bn wgrad check failed!", model_reference[1].weight.grad, model.module[1].weight.grad, 1e-6)
    passed = passed and compare("conv wgrad check failed!", model_reference[0].weight.grad, model.module[0].weight.grad)
    # can't really compare dgrad directly, as we need to scale it to account for
    # DDP
    # passed = passed and compare("dgrad check failed!", ref_tensor.grad, dgrad_tensor)
    if passed:
        print("====SBN two gpu with different batches test passed")
    else:
        assert("*failed two gpu with different batches tests*")
tests/distributed/synced_batchnorm/two_gpu_unit_test.py

@@ -114,6 +114,11 @@ grad_sbn = grad_output_t.clone().detach()
 out_sbn = sbn(inp_sbn[start:finish])
 out_sbn.backward(grad_sbn[start:finish])

+count = [ space_size**2 * ( (i+1) * batch_size // args.world_size - i * batch_size // args.world_size ) for i in range(0, args.world_size)]
+count = torch.cuda.IntTensor(count)
+
+print("--- count : " , count)
+
 sbn_result = True
 bn_result = True
@@ -136,18 +141,20 @@ grad_output2_r = ref_tensor(grad)
 grad_bias_r = grad_output_r.sum(1)
 grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)

+sum_dy_r = grad_output_r.sum(1)
 mean_dy_r = grad_output_r.mean(1)
 mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1)
+sum_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)

 grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)

-mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
-grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
+sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
+grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, sum_dy, sum_dy_xmu, count)

 if args.local_rank == 0:
     sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
     sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
-    sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
-    sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result
+    sbn_result = compare("comparing sum_dy grad: ", sum_dy, sum_dy_r, error) and sbn_result
+    sbn_result = compare("comparing sum_dy_xmu grad: ", sum_dy_xmu, sum_dy_xmu_r, error) and sbn_result
     sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
     compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
tests/distributed/synced_batchnorm/unit_test.sh

@@ -3,5 +3,6 @@ python single_gpu_unit_test.py
 python test_batchnorm1d.py
 python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py
 python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16
+python -m torch.distributed.launch --nproc_per_node=2 two_gpu_test_different_batch_size.py --apex
 #beware, you need a system with at least 4 gpus to test group_size<world_size
 #python -m torch.distributed.launch --nproc_per_node=4 test_groups.py --group_size=2