Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
778808eb
Commit
778808eb
authored
Mar 24, 2022
by
Thor Johnsen
Browse files
Halo exchangers
parent
3ade5b26
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
335 additions
and
57 deletions
+335
-57
apex/contrib/bottleneck/__init__.py
apex/contrib/bottleneck/__init__.py
+1
-0
apex/contrib/bottleneck/bottleneck.py
apex/contrib/bottleneck/bottleneck.py
+1
-57
apex/contrib/bottleneck/halo_exchangers.py
apex/contrib/bottleneck/halo_exchangers.py
+53
-0
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp
+26
-0
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu
+207
-0
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh
+47
-0
No files found.
apex/contrib/bottleneck/__init__.py
View file @
778808eb
from
.bottleneck
import
Bottleneck
,
SpatialBottleneck
from
.bottleneck
import
Bottleneck
,
SpatialBottleneck
from
.halo_exchangers
import
HaloExchangerNoComm
,
HaloExchangerAllGather
,
HaloExchangerSendRecv
apex/contrib/bottleneck/bottleneck.py
View file @
778808eb
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch
import
nn
from
torch
import
nn
from
maskrcnn_benchmark.utils.registry
import
Registry
import
fast_bottleneck
import
maskrcnn_benchmark.SpatialBottleneck
as
fast_bottleneck
import
nccl_p2p
as
inc
import
nccl_p2p
as
inc
def
kaiming_uniform_
(
tensor
,
a
=
0
,
mode
=
'fan_in'
,
nonlinearity
=
'leaky_relu'
):
def
kaiming_uniform_
(
tensor
,
a
=
0
,
mode
=
'fan_in'
,
nonlinearity
=
'leaky_relu'
):
...
@@ -392,55 +391,6 @@ class SpatialBottleneckFunction(torch.autograd.Function):
...
@@ -392,55 +391,6 @@ class SpatialBottleneckFunction(torch.autograd.Function):
# Convenience alias so the autograd Function can be called like a plain function.
spatial_bottleneck_function = SpatialBottleneckFunction.apply
class HaloExchanger(object):
    """Base class for halo exchangers; owns two side CUDA streams."""

    def __init__(self):
        self.stream1 = torch.cuda.Stream()
        self.stream2 = torch.cuda.Stream()


# Communication free halo exchanger.
# NB! This halo exchanger does not exchange halos with neighbors as it should,
# it merely swaps the inputs.
# NB! This is only useful for performance testing.
# NB! Do not use for actual production runs.
class HaloExchangerNoComm(HaloExchanger):
    def __init__(self, world_size, spatial_group_size, rank, comm):
        # Arguments are accepted only to match the common exchanger
        # constructor signature; none of them are used here.
        super(HaloExchangerNoComm, self).__init__()

    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
        # No communication: pretend each neighbor sent us its halo by
        # simply swapping the two outputs.
        return right_output_halo, left_output_halo
class HaloExchangerAllGather(HaloExchanger):
    """Exchange left/right halos by all-gathering both outgoing halos from
    every rank in the spatial group, then slicing out the two neighbor halos.
    """

    def __init__(self, world_size, spatial_group_size, rank, comm):
        super(HaloExchangerAllGather, self).__init__()
        self.spatial_group_size = spatial_group_size
        self.local_rank = rank % spatial_group_size
        self.comm = comm  # process group used for the all_gather

    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
        # Halos are N x H x W x C tensors; left and right have identical shape.
        n, hh, w, c = left_output_halo.shape
        opts = dict(dtype=left_output_halo.dtype, device=left_output_halo.device)
        # Pack both outgoing halos into one contiguous buffer along H.
        send_halos = torch.empty((n, 2 * hh, w, c), **opts)
        send_halos[:, :hh, :, :].copy_(left_output_halo)
        send_halos[:, hh:, :, :].copy_(right_output_halo)
        # One flat receive buffer, viewed as per-rank slices so the gather can
        # write straight into it. no_copy=True is an apex-patched
        # torch.distributed extension — TODO confirm the patched torch is in use.
        gathered = torch.empty((n, 2 * hh * self.spatial_group_size, w, c), **opts)
        per_rank = [gathered[:, 2 * hh * i:2 * hh * (i + 1), :, :]
                    for i in range(self.spatial_group_size)]
        torch.distributed.all_gather(per_rank, send_halos, group=self.comm, no_copy=True)
        # Wrap-around neighbors: left input is the right halo of rank-1,
        # right input is the left halo of rank+1.
        prev_rank = (self.spatial_group_size + self.local_rank - 1) % self.spatial_group_size
        next_rank = (self.local_rank + 1) % self.spatial_group_size
        left_input_halo = per_rank[prev_rank][:, hh:, :, :]
        right_input_halo = per_rank[next_rank][:, :hh, :, :]
        return left_input_halo, right_input_halo
class HaloExchangerSendRecv(HaloExchanger):
    """Exchange halos through the nccl_p2p extension (direct NCCL
    send/recv between neighbor ranks)."""

    def __init__(self, world_size, spatial_group_size, rank, comm):
        super(HaloExchangerSendRecv, self).__init__()
        self.world_size = world_size
        self.spatial_group_size = spatial_group_size
        # Every rank generates an id, but the broadcast from rank 0 makes
        # rank 0's id authoritative. The id travels through a CUDA tensor
        # because torch.distributed.broadcast operates on tensors.
        nccl_id = inc.get_unique_nccl_id(1).cuda()
        torch.distributed.broadcast(nccl_id, 0)
        nccl_id = nccl_id.cpu()
        # Handle identifying the dedicated NCCL communicator created by the
        # extension; consumed by inc.left_right_halo_exchange below.
        self.handle = inc.init_nccl_comm(nccl_id, rank, world_size)

    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
        halos = inc.left_right_halo_exchange(
            self.handle, left_output_halo, right_output_halo, self.spatial_group_size)
        left_input_halo, right_input_halo = halos
        return left_input_halo, right_input_halo
class
SpatialBottleneck
(
torch
.
nn
.
Module
):
class
SpatialBottleneck
(
torch
.
nn
.
Module
):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
...
@@ -553,9 +503,3 @@ class SpatialBottleneck(torch.nn.Module):
...
@@ -553,9 +503,3 @@ class SpatialBottleneck(torch.nn.Module):
return
out
return
out
# Registry mapping configuration-string names to halo exchanger classes, so
# the exchanger implementation can be selected by name at model-build time.
_HALO_EXCHANGERS = Registry({
    "HaloExchangerNoComm": HaloExchangerNoComm,
    "HaloExchangerAllGather": HaloExchangerAllGather,
    "HaloExchangerSendRecv": HaloExchangerSendRecv,
})
apex/contrib/bottleneck/halo_exchangers.py
0 → 100644
View file @
778808eb
import
torch
import
torch.distributed
as
dist
from
torch
import
nn
import
nccl_p2p
as
inc
class HaloExchanger(object):
    """Base class for halo exchangers; owns two side CUDA streams."""

    def __init__(self):
        self.stream1 = torch.cuda.Stream()
        self.stream2 = torch.cuda.Stream()


# Communication free halo exchanger.
# NB! This halo exchanger does not exchange halos with neighbors as it should,
# it merely swaps the inputs.
# NB! This is only useful for performance testing.
# NB! Do not use for actual production runs.
class HaloExchangerNoComm(HaloExchanger):
    def __init__(self, world_size, spatial_group_size, rank, comm):
        # Arguments are accepted only to match the common exchanger
        # constructor signature; none of them are used here.
        super(HaloExchangerNoComm, self).__init__()

    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
        # No communication: pretend each neighbor sent us its halo by
        # simply swapping the two outputs.
        return right_output_halo, left_output_halo
class HaloExchangerAllGather(HaloExchanger):
    """Exchange left/right halos by all-gathering both outgoing halos from
    every rank in the spatial group, then slicing out the two neighbor halos.
    """

    def __init__(self, world_size, spatial_group_size, rank, comm):
        super(HaloExchangerAllGather, self).__init__()
        self.spatial_group_size = spatial_group_size
        self.local_rank = rank % spatial_group_size
        self.comm = comm  # process group used for the all_gather

    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
        # Halos are N x H x W x C tensors; left and right have identical shape.
        n, hh, w, c = left_output_halo.shape
        opts = dict(dtype=left_output_halo.dtype, device=left_output_halo.device)
        # Pack both outgoing halos into one contiguous buffer along H.
        send_halos = torch.empty((n, 2 * hh, w, c), **opts)
        send_halos[:, :hh, :, :].copy_(left_output_halo)
        send_halos[:, hh:, :, :].copy_(right_output_halo)
        # One flat receive buffer, viewed as per-rank slices so the gather can
        # write straight into it. no_copy=True is an apex-patched
        # torch.distributed extension — TODO confirm the patched torch is in use.
        gathered = torch.empty((n, 2 * hh * self.spatial_group_size, w, c), **opts)
        per_rank = [gathered[:, 2 * hh * i:2 * hh * (i + 1), :, :]
                    for i in range(self.spatial_group_size)]
        torch.distributed.all_gather(per_rank, send_halos, group=self.comm, no_copy=True)
        # Wrap-around neighbors: left input is the right halo of rank-1,
        # right input is the left halo of rank+1.
        prev_rank = (self.spatial_group_size + self.local_rank - 1) % self.spatial_group_size
        next_rank = (self.local_rank + 1) % self.spatial_group_size
        left_input_halo = per_rank[prev_rank][:, hh:, :, :]
        right_input_halo = per_rank[next_rank][:, :hh, :, :]
        return left_input_halo, right_input_halo
class HaloExchangerSendRecv(HaloExchanger):
    """Exchange halos through the nccl_p2p extension (direct NCCL
    send/recv between neighbor ranks)."""

    def __init__(self, world_size, spatial_group_size, rank, comm):
        super(HaloExchangerSendRecv, self).__init__()
        self.world_size = world_size
        self.spatial_group_size = spatial_group_size
        # Every rank generates an id, but the broadcast from rank 0 makes
        # rank 0's id authoritative. The id travels through a CUDA tensor
        # because torch.distributed.broadcast operates on tensors.
        nccl_id = inc.get_unique_nccl_id(1).cuda()
        torch.distributed.broadcast(nccl_id, 0)
        nccl_id = nccl_id.cpu()
        # Handle identifying the dedicated NCCL communicator created by the
        # extension; consumed by inc.left_right_halo_exchange below.
        self.handle = inc.init_nccl_comm(nccl_id, rank, world_size)

    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
        halos = inc.left_right_halo_exchange(
            self.handle, left_output_halo, right_output_halo, self.spatial_group_size)
        left_input_halo, right_input_halo = halos
        return left_input_halo, right_input_halo
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp
0 → 100644
View file @
778808eb
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nccl_p2p_cuda.cuh"
// Python bindings for the nccl_p2p extension: expose NCCL communicator setup
// and the point-to-point halo-exchange entry points.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    namespace p2p = apex::contrib::nccl_p2p;
    m.def("get_unique_nccl_id", &p2p::get_unique_nccl_id, "get_unique_nccl_id");
    m.def("init_nccl_comm", &p2p::init_nccl_comm, "init_nccl_comm");
    m.def("nccl_send", &p2p::nccl_send, "nccl_send");
    m.def("nccl_recv", &p2p::nccl_recv, "nccl_recv");
    m.def("left_right_halo_exchange", &p2p::left_right_halo_exchange, "left_right_halo_exchange");
    m.def("add_delay", &p2p::add_delay, "add_delay");
}
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu
0 → 100644
View file @
778808eb
#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <ctime>
#include <cassert>
#include "nccl.h"
/*
 * This file implements a crude but effective mechanism for copying data
 * between tensors owned by different ranks on the same machine using NCCL
 * point-to-point transfers (ncclSend / ncclRecv).
 */
namespace
{
// Busy-wait kernel used to stall a CUDA stream for approximately `delay`
// units (scaled as nanoseconds below). Only thread (0,0) spins; the loop
// count is written to *counter so the spin cannot be optimized away.
__global__ void AddDelay_kernel(const int delay, int* counter)
{
    if (blockIdx.x == 0 && threadIdx.x == 0)
    {
        // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.
        int new_counter = 0;
        double elapsed = 0;
        clock_t start = clock();
        do
        {
            clock_t now = clock();
            // NOTE(review): scaling by 1e9/CLOCKS_PER_SEC treats `delay` as
            // nanoseconds, but on-device clock() ticks at the SM clock rate,
            // not CLOCKS_PER_SEC — confirm the intended time base.
            elapsed = (double)(now - start) * 1e9 / CLOCKS_PER_SEC;
            ++new_counter;
        } while (elapsed < (double)delay);
        // Publish the loop count; side effect keeps the loop alive.
        *counter = new_counter;
    }
}
}
// Thin RAII-ish wrapper around one ncclComm_t plus the owning rank info.
// NOTE(review): the implicitly generated copy constructor shallow-copies
// `comm`, so every copy's destructor calls ncclCommDestroy on the same
// communicator (Rule of Three violation). Hold instances by reference
// (see left_right_halo_exchange in this file); a proper fix deletes copy
// and defines noexcept moves.
class NcclCommWrapper
{
private:
    ncclComm_t comm;
    int rank, world_size;

    // Map a torch scalar type to the matching NCCL datatype.
    // Falls through to assert(false) for unsupported types.
    ncclDataType_t get_nccl_type(at::Tensor input)
    {
        switch (input.scalar_type())
        {
            case at::ScalarType::Half:     return ncclFloat16;
            case at::ScalarType::Float:    return ncclFloat32;
            case at::ScalarType::Double:   return ncclFloat64;
            case at::ScalarType::Byte:     return ncclUint8;
            case at::ScalarType::Char:     return ncclInt8;
            case at::ScalarType::Int:      return ncclInt32;
            case at::ScalarType::Long:     return ncclInt64;
            case at::ScalarType::BFloat16: return ncclBfloat16;
            default:                       assert(false);
        }
    }

public:
    NcclCommWrapper()
    {
        memset(&comm, 0, sizeof(ncclComm_t));
        rank = 0;
        world_size = 0;
    }

    // Join the NCCL communicator identified by `id` as rank my_rank of num_ranks.
    NcclCommWrapper(ncclUniqueId id, int my_rank, int num_ranks)
    {
        ncclCommInitRank(&comm, num_ranks, id, my_rank);
        rank = my_rank;
        world_size = num_ranks;
    }

    ~NcclCommWrapper()
    {
        printf("ncclCommDestroy()\n");
        ncclCommDestroy(comm);
    }

    // Send `input` to rank `destination` on the current CUDA stream.
    void send(at::Tensor input, int destination)
    {
        ncclDataType_t ncclType = get_nccl_type(input);
        AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "nccl_send", [&]() {
            // BUG FIX: ncclSend takes an ELEMENT count, not a byte count.
            // The previous sizeof(scalar_t) * numel sent sizeof(scalar_t)x
            // too much data, overrunning the receive buffer.
            size_t count = torch::numel(input);
            auto input_ptr = input.data_ptr<scalar_t>();
            ncclSend(input_ptr, count, ncclType, destination, comm, at::cuda::getCurrentCUDAStream());
        });
    }

    // Receive into `input` from rank `sender` on the current CUDA stream.
    void recv(at::Tensor input, int sender)
    {
        ncclDataType_t ncclType = get_nccl_type(input);
        AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "nccl_recv", [&]() {
            // BUG FIX: element count, not byte count (see send()).
            size_t count = torch::numel(input);
            auto input_ptr = input.data_ptr<scalar_t>();
            ncclRecv(input_ptr, count, ncclType, sender, comm, at::cuda::getCurrentCUDAStream());
        });
    }

    // Exchange halos with neighbor ranks within a group of `group_size`
    // consecutive global ranks. Returns {left_input_halo, right_input_halo}.
    // after halo exchange:
    // left_output_halo of rank+1 ends up in right_input_halo of rank
    // right_output_halo of rank-1 ends up in left_input_halo of rank
    std::vector<at::Tensor> left_right_halo_exchange(at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
    {
        auto stream = at::cuda::getCurrentCUDAStream();
        auto right_input_halo = torch::empty_like(left_output_halo);
        auto left_input_halo = torch::empty_like(right_output_halo);
        ncclGroupStart();
        ncclDataType_t ncclType = get_nccl_type(left_output_halo);
        // we use wrap-around ranks, so left_input_halo of rank 0 has right_output_halo of rank world_size-1 after exchange etc.
        // this is technically speaking wasteful, but there is no benefit in having the edge ranks do less work than internal ranks.
        // NOTE(review): despite the wrap-around comment above, the guards
        // below skip the exchange at group edges (rank 0 / group_size-1),
        // so edges do NOT wrap here — confirm which behavior is intended.
        int group_rank = rank % group_size;
        int group_index = rank / group_size;
        int prev_rank = (group_rank + group_size - 1) % group_size;
        int next_rank = (group_rank + 1) % group_size;
        prev_rank = prev_rank + group_index * group_size;
        next_rank = next_rank + group_index * group_size;
        size_t left_n = torch::numel(left_output_halo);
        size_t right_n = torch::numel(right_output_halo);
        if (group_rank > 0)
        {
            AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, left_output_halo.scalar_type(), "left_halo_exch", [&]() {
                // send left (to my_rank - 1)
                ncclSend(left_output_halo.data_ptr<scalar_t>(), left_n, ncclType, prev_rank, comm, stream);
                // receive left (from my_rank - 1)
                ncclRecv(left_input_halo.data_ptr<scalar_t>(), right_n, ncclType, prev_rank, comm, stream);
            });
        }
        if (group_rank < group_size - 1)
        {
            AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, right_output_halo.scalar_type(), "right_halo_exch", [&]() {
                // send right (to my_rank + 1)
                ncclSend(right_output_halo.data_ptr<scalar_t>(), right_n, ncclType, next_rank, comm, stream);
                // receive right (from my_rank + 1)
                ncclRecv(right_input_halo.data_ptr<scalar_t>(), left_n, ncclType, next_rank, comm, stream);
            });
        }
        ncclGroupEnd();
        return {left_input_halo, right_input_halo};
    }
};
std
::
vector
<
NcclCommWrapper
>
nccl_comms
;
}
// end anonymous namespace
namespace
apex
{
namespace
contrib
{
namespace
nccl_p2p
{
// Generate `n` NCCL unique ids and return them packed back-to-back in a CPU
// uint8 tensor (n * sizeof(ncclUniqueId) bytes), suitable for broadcasting
// to other ranks via torch.distributed.
at::Tensor get_unique_nccl_id(int n)
{
    // BUG FIX: the previous version generated one extra ncclUniqueId here
    // that was never used (dead code; the loop below regenerates fresh ids).
    auto id_tensor = torch::empty(
        {n * (int)sizeof(ncclUniqueId)},
        torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));
    auto id_ptr = id_tensor.data_ptr<uint8_t>();
    size_t offset = 0;
    for (int i = 0; i < n; ++i)
    {
        ncclUniqueId id;
        ncclGetUniqueId(&id);
        memcpy(id_ptr + offset, &id, sizeof(ncclUniqueId));
        offset += sizeof(ncclUniqueId);
    }
    return id_tensor;
}
// Create a NcclCommWrapper from a broadcast unique id and register it in the
// global nccl_comms table. Returns the integer handle (index) callers pass
// to nccl_send / nccl_recv / left_right_halo_exchange.
int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks)
{
    ncclUniqueId id;
    auto unique_nccl_id_ptr = unique_nccl_id.data_ptr<uint8_t>();
    memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId));
    int handle = (int)nccl_comms.size();
    // BUG FIX: the previous version heap-allocated a wrapper, copied it into
    // the vector and leaked the allocation. emplace_back constructs in place
    // — no leak and no temporary whose destructor would tear down the comm.
    // NOTE(review): vector reallocation still copy-destroys live wrappers
    // (pre-existing issue; see the Rule-of-Three note on NcclCommWrapper).
    nccl_comms.emplace_back(id, my_rank, num_ranks);
    return handle;
}
void
nccl_send
(
int
handle
,
at
::
Tensor
input
,
int
destination
)
{
assert
(
handle
>=
0
&&
handle
<
nccl_comms
.
size
());
class
NcclCommWrapper
communicator
=
nccl_comms
[
handle
];
communicator
.
send
(
input
,
destination
);
}
void
nccl_recv
(
int
handle
,
at
::
Tensor
input
,
int
sender
)
{
assert
(
handle
>=
0
&&
handle
<
nccl_comms
.
size
());
class
NcclCommWrapper
communicator
=
nccl_comms
[
handle
];
communicator
.
recv
(
input
,
sender
);
}
// Look up the communicator created by init_nccl_comm and delegate the halo
// exchange to it. Returns {left_input_halo, right_input_halo}.
std::vector<at::Tensor> left_right_halo_exchange(int handle, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
{
    assert(handle >= 0 && handle < nccl_comms.size());
    auto& wrapper = nccl_comms[handle];
    return wrapper.left_right_halo_exchange(left_output_halo, right_output_halo, group_size);
}
// Enqueue AddDelay_kernel on the current CUDA stream, stalling it for roughly
// `delay` time units. The scratch tensor receives the kernel's loop count
// (needed only so the busy-wait cannot be optimized away).
void add_delay(int delay)
{
    auto stream = at::cuda::getCurrentCUDAStream();
    auto scratch = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
    AddDelay_kernel<<<1, 1, 0, stream>>>(delay, scratch.data_ptr<int>());
}
}}}
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh
0 → 100644
View file @
778808eb
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>
#ifndef _nccl_p2p_h_
#define _nccl_p2p_h_
namespace apex { namespace contrib { namespace nccl_p2p {

// Generate n ncclUniqueId values, packed into a CPU uint8 tensor.
at::Tensor get_unique_nccl_id(int n);

// Join a NCCL communicator from a broadcast unique id; returns an integer
// handle consumed by the functions below.
int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks);

// Point-to-point send/recv of a tensor through communicator `handle`.
void nccl_send(int handle, at::Tensor input, int destination);
void nccl_recv(int handle, at::Tensor input, int sender);

// Exchange left/right halos with neighbor ranks inside a spatial group of
// `group_size` ranks; returns {left_input_halo, right_input_halo}.
std::vector<at::Tensor> left_right_halo_exchange(
    int handle,
    at::Tensor left_output_halo,
    at::Tensor right_output_halo,
    int group_size);

// Enqueue a busy-wait kernel of ~`delay` units on the current CUDA stream.
void add_delay(int delay);

}}}
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment