OpenDAS / apex · commit 795a5e5b

Commit 795a5e5b, authored Jul 29, 2022 by hubertlu-tw
    Merge remote-tracking branch 'upstream/master' into IFU-master-2022-07-29
Parents: 016c8d4f, 3c19f106
Changes: 230 files in total; this page shows 20 changed files with 1089 additions and 57 deletions (+1089 −57).
Changed files shown on this page:

apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu                          +6    −0
apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu                               +97   −1
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu         +1    −1
apex/contrib/csrc/multihead_attn/dropout.cuh                                     +1    −1
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu          +1    −2
apex/contrib/csrc/multihead_attn/layer_norm.cuh                                  +7    −7
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu  +0    −27
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu                +1    −2
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu            +1    −2
apex/contrib/csrc/multihead_attn/softmax.cuh                                     +7    −5
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp                                          +25   −0
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu                                      +211  −0
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh                                     +45   −0
apex/contrib/csrc/peer_memory/peer_memory.cpp                                    +29   −0
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu                                +529  −0
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh                               +48   −0
apex/contrib/csrc/transducer/transducer_joint_kernel.cu                          +1    −1
apex/contrib/fmha/fmha.py                                                        +10   −8
apex/contrib/focal_loss/__init__.py                                              +9    −0
apex/contrib/focal_loss/focal_loss.py                                            +60   −0
apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu

@@ -166,6 +166,12 @@ REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4,  4, 4);
REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4,  8, 4);
apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu

@@ -67,9 +67,105 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
}
REGISTER_FWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(  768, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(  768, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(  768, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(  768, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4,  4);
REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4,  4);
REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4,  4);
REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4,  4);
REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4,  4);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4,  4);
REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4,  4);
REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4,  4);
REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4,  4);
REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4,  4);
REGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4,  8);
REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4,  8);
REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4,  8);
REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4,  8);
REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4,  8);
REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4,  8);
REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu

@@ -110,4 +110,4 @@ torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
}
} // namespace additive_mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
\ No newline at end of file
} // namespace multihead_attn
apex/contrib/csrc/multihead_attn/dropout.cuh

#pragma once
#include <ATen/ATen.h>
#if !defined(NEW_GENERATOR_PATH)
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu

@@ -687,5 +687,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace encdec_norm_add
} // end namespace multihead_attn
} // end namespace multihead_attn
\ No newline at end of file
apex/contrib/csrc/multihead_attn/layer_norm.cuh

@@ -67,7 +67,7 @@ __device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1,
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x + (1 << l)) & 31;
      int srcLaneB = (threadIdx.x + (1 << l)) & 31;
      U muB = WARP_SHFL(mu, srcLaneB, 32);
      U countB = WARP_SHFL(count, srcLaneB, 32);
      U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
@@ -108,7 +108,7 @@ __device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1,
    // don't care about final value of count, we know count == n2
  } else {
    mu = WARP_SHFL(mu, 0, 32);
    sigma2 = WARP_SHFL(sigma2 / U(n2), 0, 32);
    sigma2 = WARP_SHFL(sigma2 / U(n2), 0, 32);
  }
}
}
@@ -158,7 +158,7 @@ __device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals,
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x + (1 << l)) & 31;
      int srcLaneB = (threadIdx.x + (1 << l)) & 31;
      float muB = WARP_SHFL(mu, srcLaneB, 32);
      float countB = WARP_SHFL(count, srcLaneB, 32);
      float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
@@ -199,7 +199,7 @@ __device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals,
    // don't care about final value of count, we know count == n2
  } else {
    mu = WARP_SHFL(mu, 0, 32);
    sigma2 = WARP_SHFL(sigma2 / float(n2), 0, 32);
    sigma2 = WARP_SHFL(sigma2 / float(n2), 0, 32);
  }
}
}
@@ -261,7 +261,7 @@ cuApplyLayerNorm(T *__restrict__ output_vals, U *__restrict__ mean,
  // 1) blockDim.x == warpSize
  // 2) Tensors are contiguous
  //
  for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    SharedMemory<U> shared;
    U *buf = shared.getPointer();
    U mu, sigma2;
@@ -475,7 +475,7 @@ cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid,
                   const T *__restrict__ input, const int n1, const int n2,
                   const U *__restrict__ mean, const U *__restrict__ invvar,
                   U epsilon, const T *gamma, T *grad_input) {
  for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
@@ -521,7 +521,7 @@ cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid,
      }
    }
    // intra-warp reductions
    for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
    for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
    }
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu

@@ -19,7 +19,6 @@ namespace multihead_attn {
namespace self_bias_additive_mask {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, int heads,
@@ -50,32 +49,6 @@ std::vector<torch::Tensor> fwd_cuda(
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);
  // 3 Intermediate Results + Output (Note: dropout intermediates are generated
  // by ATen library code)
  auto act_options = inputs.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);
  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor bmm1_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor outputs = torch::empty_like(inputs, act_options);
  // Input Linear Results Pointers to Q, K, and V of interviewed activations
  void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
  void *k_lin_results_ptr = static_cast<void *>(static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
  void *v_lin_results_ptr = static_cast<void *>(static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu

@@ -501,5 +501,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
} // end namespace multihead_attn
\ No newline at end of file
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu

@@ -577,5 +577,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
} // end namespace multihead_attn
\ No newline at end of file
apex/contrib/csrc/multihead_attn/softmax.cuh

@@ -3,7 +3,7 @@
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <curand_kernel.h>
#if !defined(NEW_GENERATOR_PATH)
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
@@ -1593,11 +1593,13 @@ int log2_ceil_native(int value) {
}
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000 && !defined(__HIP_PLATFORM_HCC__)
  return __shfl_xor_sync(mask, value, laneMask, width);
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  return __shfl_xor(value, laneMask, width);
  return __shfl_xor(value, laneMask, width);
#endif
}
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp (new file, mode 100644)

/**
 * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nccl_p2p_cuda.cuh"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id");
    m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm");
    m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace");
    m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange");
    m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay");
}
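As a rough illustration of how these bindings fit together, here is a minimal Python sketch. The extension module name, the gloo-based broadcast of the NCCL id, and the neighbor bookkeeping are assumptions for illustration only, not part of this diff.

# Hypothetical driver for the nccl_p2p bindings above.
import torch
import torch.distributed as dist
import nccl_p2p_cuda as inc   # assumed name of the built extension

def init_halo_comm():
    rank, world = dist.get_rank(), dist.get_world_size()
    # Every rank allocates the id buffer; rank 0's ncclUniqueId is then broadcast
    # over a CPU-capable backend (e.g. gloo) so all ranks join the same communicator.
    unique_id = inc.get_unique_nccl_id(1)
    dist.broadcast(unique_id, src=0)
    handle = inc.init_nccl_comm(unique_id, rank, world)
    # A negative neighbor rank means "no neighbor": that side's input halo is zeroed.
    left = rank - 1 if rank > 0 else -1
    right = rank + 1 if rank + 1 < world else -1
    return handle, left, right

def exchange_rows(handle, left, right, tile):
    # Send the first/last rows of a [H, W] CUDA tile; receive the neighbors' rows back.
    left_in, right_in = inc.left_right_halo_exchange(
        handle, left, right, tile[:1].contiguous(), tile[-1:].contiguous())
    return left_in, right_in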
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu (new file, mode 100644)

#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <ctime>
#include <cassert>
#include "nccl.h"

/*
 * This file implements a crude but effective mechanism for copying data between tensors owned by different ranks
 * on the same machine using cudaMemcpyAsync peer-to-peer transfers.
 */

namespace {

__global__ void AddDelay_kernel(const int delay, int* counter) {
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.
        int new_counter = 0;
        double elapsed = 0;
        clock_t start = clock();
        do {
            clock_t now = clock();
            elapsed = (double)(now - start) * 1e9 / CLOCKS_PER_SEC;
            ++new_counter;
        } while (elapsed < (double)delay);
        *counter = new_counter;
    }
}

class NcclCommWrapper {
private:
    ncclComm_t comm;
    int rank, world_size;

    ncclDataType_t get_nccl_type(at::Tensor input) {
        switch (input.scalar_type()) {
            case at::ScalarType::Half: return ncclFloat16;
            case at::ScalarType::Float: return ncclFloat32;
            case at::ScalarType::Double: return ncclFloat64;
            case at::ScalarType::Byte: return ncclUint8;
            case at::ScalarType::Char: return ncclInt8;
            case at::ScalarType::Int: return ncclInt32;
            case at::ScalarType::Long: return ncclInt64;
            case at::ScalarType::BFloat16: return ncclBfloat16;
            default: assert(false);
        }
    }

public:
    NcclCommWrapper() {
        memset(&comm, 0, sizeof(ncclComm_t));
        rank = 0;
        world_size = 0;
    }
    NcclCommWrapper(ncclUniqueId id, int my_rank, int num_ranks) {
        ncclCommInitRank(&comm, num_ranks, id, my_rank);
        rank = my_rank;
        world_size = num_ranks;
    }
    ~NcclCommWrapper() {
        printf("ncclCommDestroy()\n");
        ncclCommDestroy(comm);
    }

    void left_right_halo_exchange_inplace(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo) {
        auto stream = at::cuda::getCurrentCUDAStream();
        ncclGroupStart();
        ncclDataType_t ncclType = get_nccl_type(left_output_halo);
        bool left_zero = (left_rank < 0);
        bool right_zero = (right_rank < 0);
        size_t left_n = torch::numel(left_output_halo);
        size_t right_n = torch::numel(right_output_halo);
        assert(left_n > 0 && left_n == right_n);
        if (left_zero) {
            left_input_halo.zero_();
        } else {
            AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, left_output_halo.scalar_type(), "left_halo_exch", [&]() {
                // send left (to my_rank - 1)
                ncclSend(left_output_halo.data_ptr<scalar_t>(), left_n, ncclType, left_rank, comm, stream);
                // receive left (from my_rank - 1)
                ncclRecv(left_input_halo.data_ptr<scalar_t>(), right_n, ncclType, left_rank, comm, stream);
            });
        }
        if (right_zero) {
            right_input_halo.zero_();
        } else {
            AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, right_output_halo.scalar_type(), "right_halo_exch", [&]() {
                // send right (to my_rank + 1 )
                ncclSend(right_output_halo.data_ptr<scalar_t>(), right_n, ncclType, right_rank, comm, stream);
                // receive right (from my_rank + 1)
                ncclRecv(right_input_halo.data_ptr<scalar_t>(), left_n, ncclType, right_rank, comm, stream);
            });
        }
        ncclGroupEnd();
    }

    std::vector<at::Tensor> left_right_halo_exchange(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo) {
        // after halo exchange:
        // left_output_halo of rank+1 ends up in right_input_halo of rank
        // right_output_halo of rank-1 ends up in left_input_halo of rank
        auto right_input_halo = torch::empty_like(left_output_halo);
        auto left_input_halo = torch::empty_like(right_output_halo);
        left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo);
        return {left_input_halo, right_input_halo};
    }
};

class ManagedObjects {
public:
    ManagedObjects() {}
    ~ManagedObjects() {
        for (auto it = _nccl_comms.begin(); it != _nccl_comms.end(); ++it) {
            delete *it;
        }
    }
    int add_comm(NcclCommWrapper* comm) {
        int handle = _nccl_comms.size();
        _nccl_comms.push_back(comm);
        return handle;
    }
    NcclCommWrapper& get_comm(int handle) {
        assert(handle >= 0 && handle < _nccl_comms.size());
        return *_nccl_comms[handle];
    }
private:
    std::vector<NcclCommWrapper*> _nccl_comms;
};
class ManagedObjects mo;

} // end anonymous namespace

namespace apex { namespace contrib { namespace nccl_p2p {

at::Tensor get_unique_nccl_id(int n) {
    ncclUniqueId id;
    ncclGetUniqueId(&id);
    auto id_tensor = torch::empty({n, (int)sizeof(ncclUniqueId)}, torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));
    auto id_ptr = id_tensor.data_ptr<uint8_t>();
    size_t offset = 0;
    for (int i = 0; i < n; ++i) {
        ncclUniqueId id;
        ncclGetUniqueId(&id);
        memcpy(id_ptr + offset, &id, sizeof(ncclUniqueId));
        offset += sizeof(ncclUniqueId);
    }
    return id_tensor;
}

int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks) {
    ncclUniqueId id;
    auto unique_nccl_id_ptr = unique_nccl_id.data_ptr<uint8_t>();
    memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId));
    NcclCommWrapper* comm = new NcclCommWrapper(id, my_rank, num_ranks);
    int handle = mo.add_comm(comm);
    comm = 0L;
    return handle;
}

void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo) {
    class NcclCommWrapper& communicator = mo.get_comm(handle);
    return communicator.left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo);
}

std::vector<at::Tensor> left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo) {
    class NcclCommWrapper& communicator = mo.get_comm(handle);
    return communicator.left_right_halo_exchange(left_rank, right_rank, left_output_halo, right_output_halo);
}

void add_delay(int delay) {
    auto stream = at::cuda::getCurrentCUDAStream();
    auto t = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
    AddDelay_kernel<<<1, 1, 0, stream>>>(delay, t.data_ptr<int>());
}

}}}
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh (new file, mode 100644)

/**
 * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once
#include <torch/extension.h>
#ifndef _nccl_p2p_h_
#define _nccl_p2p_h_

namespace apex { namespace contrib { namespace nccl_p2p {

at::Tensor get_unique_nccl_id(int n);
int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks);
void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo);
std::vector<at::Tensor> left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo);
void add_delay(int delay);

}}}
#endif
apex/contrib/csrc/peer_memory/peer_memory.cpp (new file, mode 100644)

/**
 * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "peer_memory_cuda.cuh"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw");
    m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw");
    m.def("zero", &apex::contrib::peer_memory::zero, "zero");
    m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address");
    m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers");
    m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half");
    m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float");
    m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int");
    m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d");
}
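A minimal sketch of setting up a peer-memory pool with these bindings follows. The extension module name and the gloo all_gather used to move the IPC handles between processes are assumptions for illustration, not part of this diff.

# Hypothetical setup of a per-rank peer memory pool.
import torch
import torch.distributed as dist
import peer_memory_cuda as pm   # assumed name of the built extension

def create_peer_pool(pool_bytes):
    rank, world = dist.get_rank(), dist.get_world_size()
    raw = pm.allocate_raw(pool_bytes)          # cudaMalloc'd and zeroed on this rank
    ipc = pm.get_raw_ipc_address(raw)          # uint8 CPU tensor wrapping a cudaIpcMemHandle_t
    gathered = [torch.empty_like(ipc) for _ in range(world)]
    dist.all_gather(gathered, ipc)             # share handles between ranks on the same node
    # Opens every other rank's handle with cudaIpcOpenMemHandle; entry [rank] is our own pointer.
    peers = pm.get_raw_peers(torch.stack(gathered), rank, raw)
    return raw, peers

# Carving a zero-copy CUDA tensor out of a peer's pool (layout agreed on by both sides):
#   halo_view = pm.blob_view_float(peers[neighbor_rank], [1, C, H, W], False)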
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu (new file, mode 100644)

#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <cassert>
#include <cuda_runtime_api.h>
#include <cooperative_groups.h>
#include "nccl.h"

namespace cg = cooperative_groups;

#define CUDACHECK(cmd) do { \
    cudaError_t err = cmd; \
    if( err != cudaSuccess ) { \
        char hostname[1024]; \
        gethostname(hostname, 1024); \
        printf("%s: CUDA failure %s:%d '%s'\n", \
            hostname, \
            __FILE__,__LINE__,cudaGetErrorString(err)); \
    } \
} while(0)

namespace {

/* Basic deleter function for from_blob function.
void deleter(void* ptr)
{
    printf("deleter(ptr=%p)\n",ptr);
    cudaFree(ptr);
}
*/

template<class T>
at::Tensor blob_view(T* raw_ptr, std::vector<int64_t> shape, const at::TensorOptions& options, bool channels_last) {
    size_t size = 1;
    std::vector<int64_t> strides(shape.size());
    if (channels_last) {
        assert(shape.size() == 4);
        strides[0] = shape[1] * shape[2] * shape[3];
        strides[1] = 1;
        strides[2] = shape[1] * shape[3];
        strides[3] = shape[1];
    } else {
        int idx = strides.size();
        for (auto it = shape.rbegin(); it != shape.rend(); ++it) {
            strides[--idx] = size;
            size *= *it;
        }
    }
    size *= sizeof(T);
    // TODO: Implement dynamic reuse of pooled peer memory.
    // We provide no deleter function because all peer memory allocations are static in this implementation.
    return torch::from_blob((void*)raw_ptr, shape, strides, 0L, options);
}

void tensor_shape(at::Tensor t, bool explicit_nhwc, int& N, int& C, int& H, int& W) {
    if (t.dim() == 3) {
        N = 1;
        if (explicit_nhwc) {
            C = t.size(2);
            H = t.size(0);
            W = t.size(1);
        } else {
            C = t.size(0);
            H = t.size(1);
            W = t.size(2);
        }
    } else if (t.dim() == 4) {
        if (explicit_nhwc) {
            N = t.size(0);
            C = t.size(3);
            H = t.size(1);
            W = t.size(2);
        } else {
            N = t.size(0);
            C = t.size(1);
            H = t.size(2);
            W = t.size(3);
        }
    } else {
        printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n", __FILE__, __LINE__, t.dim());
        assert(t.dim() == 3 || t.dim() == 4);
    }
}

void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride_C, int& stride_H, int& stride_W) {
    if (t.dim() == 3) {
        if (explicit_nhwc) {
            stride_C = t.stride(2);
            stride_H = t.stride(0);
            stride_W = t.stride(1);
        } else {
            stride_C = t.stride(0);
            stride_H = t.stride(1);
            stride_W = t.stride(2);
        }
        stride_N = t.size(0) * t.size(1) * t.size(2);
    } else if (t.dim() == 4) {
        if (explicit_nhwc) {
            stride_N = t.stride(0);
            stride_C = t.stride(3);
            stride_H = t.stride(1);
            stride_W = t.stride(2);
        } else {
            stride_N = t.stride(0);
            stride_C = t.stride(1);
            stride_H = t.stride(2);
            stride_W = t.stride(3);
        }
    } else {
        printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n", __FILE__, __LINE__, t.dim());
        assert(t.dim() == 3 || t.dim() == 4);
    }
}

template<class T, bool is_HWC>
__device__ void strided_copy_kernel(
        T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W,
        const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W,
        const int NC, const int NH, const int NW) {
    size_t tot_num_threads = gridDim.x * blockDim.x;
    size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
    const size_t count = NC * NH * NW;
    for (size_t i = thread_id; i < count; i += tot_num_threads) {
        size_t c, h, w;
        if (is_HWC) {
            c = i % NC;
            w = i / NC;
            h = w / NW;
            w = w % NW;
        } else {
            w = i % NW;
            h = i / NW;
            c = h / NH;
            h = h % NH;
        }
        size_t dst_off = c * dst_stride_C + h * dst_stride_H + w * dst_stride_W;
        size_t src_off = c * src_stride_C + h * src_stride_H + w * src_stride_W;
        dst[dst_off] = src[src_off];
    }
}

__device__ void checked_signal(volatile int* signal1_flag, volatile int* signal2_flag, const int v1, const int v2, const int v3, const int v4) {
    cg::this_grid().sync();
    bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
    if (is_main_thread) {
        // flush all writes to global memory
        __threadfence_system();
        // wait for top or bottom neighbor to clear signal
        register int r1, r2, r3, r4;
        bool top_zeroed = false, btm_zeroed = false, top_done = false, btm_done = false;
        do {
            do {
                if (!top_zeroed) {
                    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
                }
                if (!btm_zeroed) {
                    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
                }
            } while ((top_zeroed == top_done) && (btm_zeroed == btm_done));
            if (!top_done && top_zeroed) {
                // signal to top neighbor my output is ready
                asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
                top_done = true;
            }
            if (!btm_done && btm_zeroed) {
                // signal to bottom neighbor my output is ready
                asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
                btm_done = true;
            }
        } while (!top_done || !btm_done);
    }
}

__device__ void wait_for(volatile int* wait_flag, const int v1, const int v2, const int v3, const int v4) {
    bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
    if (is_main_thread) {
        register int r1, r2, r3, r4;
        // wait for senders to signal their output is read
        do {
            asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
        } while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
    }
    cg::this_grid().sync();
    // all threads wait for main
}

__device__ void clear_flag(volatile int* wait_flag) {
    cg::this_grid().sync();
    // wait for all threads in kernel to finish
    bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
    if (is_main_thread) {
        register int r1, r2, r3, r4;
        r1 = 0;
        r2 = 0;
        r3 = 0;
        r4 = 0;
        asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
    }
}

template<class T, bool is_HWC>
#if __CUDA_ARCH__ >= 700
__launch_bounds__(128, 16)
#endif
__global__ void push_pull_halos_1d_kernel(
        // top halo,
        const T* toh, int toh_stride_C, int toh_stride_H, int toh_stride_W,   // top output halo
        T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W,         // top output tx buffer
        T* tix, int tix_stride_C, int tix_stride_H, int tix_stride_W,         // top input tx buffer
        T* tih, int tih_stride_C, int tih_stride_H, int tih_stride_W,         // top input halo
        // btm halo
        const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W,   // btm output halo
        T* box, int box_stride_C, int box_stride_H, int box_stride_W,         // btm output tx buffer
        T* bix, int bix_stride_C, int bix_stride_H, int bix_stride_W,         // btm input tx buffer
        T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W,         // btm input halo
        // dimensions
        int NC, int NH, int NW,
        // signals
        int* signal1_flag, int* signal2_flag, int* wait1_flag, int* wait2_flag) {
    // push top output halo to transfer buffer
    strided_copy_kernel<T, is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
    // push btm output halo to transfer buffer
    strided_copy_kernel<T, is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
    // signal to top and btm neighbors that output halos are ready to be read
    // the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
    checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
    // pull top halo from transfer buffer in peer memory to input
    wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
    strided_copy_kernel<T, is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
    clear_flag(wait1_flag);
    // pull btm halo from transfer buffer in peer memory to input
    wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
    strided_copy_kernel<T, is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
    clear_flag(wait2_flag);
}

__global__ void delay_kernel(int delay_nanoseconds, int* counter) {
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.
        int new_counter = 0;
        double elapsed = 0;
        clock_t start = clock();
        do {
            clock_t now = clock();
            elapsed = (double)(now - start) * 1e9 / CLOCKS_PER_SEC;
            ++new_counter;
        } while (elapsed < (double)delay_nanoseconds);
        *counter = new_counter;
    }
}

}

namespace apex { namespace contrib { namespace peer_memory {

int64_t allocate_raw(int64_t size) {
    float* ptr = 0L;
    cudaMalloc(&ptr, size);
    cudaMemset(ptr, 0, size);
    return (int64_t)ptr;
}

void free_raw(int64_t raw) {
    cudaFree((void*)raw);
}

void zero(int64_t raw, int64_t size) {
    cudaMemset((void*)raw, 0, size);
}

at::Tensor get_raw_ipc_address(int64_t raw) {
    cudaIpcMemHandle_t mem_handle;
    CUDACHECK(cudaIpcGetMemHandle(&mem_handle, (void*)raw));
    const int n = sizeof(cudaIpcMemHandle_t);
    auto address_tensor = torch::empty({n}, torch::dtype(torch::kUInt8));
    auto address_tensor_p = address_tensor.data_ptr<uint8_t>();
    memcpy(address_tensor_p, (uint8_t*)&mem_handle, n);
    return address_tensor;
}

std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw) {
    int peer_group_size = ipc_addresses.size(0);
    std::vector<int64_t> results(peer_group_size);
    for (int i = 0; i < peer_group_size; ++i) {
        if (i != peer_rank) {
            cudaIpcMemHandle_t mem_handle;
            memcpy(&mem_handle, ipc_addresses.index({i}).data_ptr<uint8_t>(), sizeof(cudaIpcMemHandle_t));
            void* p = 0L;
            CUDACHECK(cudaIpcOpenMemHandle((void**)&p, mem_handle, cudaIpcMemLazyEnablePeerAccess));
            results[i] = (int64_t)p;
        } else {
            results[i] = (int64_t)raw;
        }
    }
    return results;
}

at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last) {
    return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last);
}

at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last) {
    return blob_view<float>((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last);
}

at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last) {
    return blob_view<int>((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last);
}

void push_pull_halos_1d(
        bool diagnostics, bool explicit_nhwc, int numSM,   // number of SMs to use
        at::Tensor top_out_halo,   // top output halo in sender device memory
        at::Tensor top_out_tx,     // top output transfer buffer in sender peer pool memory
        at::Tensor top_inp_tx,     // top input transfer buffer in top neighbor peer pool memory
        at::Tensor top_inp_halo,   // top input halo in receiver device memory
        at::Tensor btm_out_halo,   // btm output halo in sender device memory
        at::Tensor btm_out_tx,     // btm output transfer buffer in sender peer pool memory
        at::Tensor btm_inp_tx,     // btm input transfer buffer in btm neighbor peer pool memory
        at::Tensor btm_inp_halo,   // btm input halo in receiver device memory
        at::Tensor top_signal,     // top input signal in receiver device memory
        at::Tensor btm_signal,     // btm input signal in receiver device memory
        at::Tensor waits           // top and btm signals for this rank
        ) {
    // basic checks of inputs
    TORCH_CHECK(top_out_halo.is_cuda());
    TORCH_CHECK(top_out_tx.is_cuda());
    TORCH_CHECK(top_inp_tx.is_cuda());
    TORCH_CHECK(top_inp_halo.is_cuda());
    TORCH_CHECK(btm_out_halo.is_cuda());
    TORCH_CHECK(btm_out_tx.is_cuda());
    TORCH_CHECK(btm_inp_tx.is_cuda());
    TORCH_CHECK(btm_inp_halo.is_cuda());
    TORCH_CHECK(top_signal.is_cuda());
    TORCH_CHECK(btm_signal.is_cuda());
    TORCH_CHECK(waits.is_cuda());
    // shapes and strides
    int toh_N, toh_C, toh_H, toh_W;
    tensor_shape(top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W);
    int tox_N, tox_C, tox_H, tox_W;
    tensor_shape(top_out_tx, explicit_nhwc, tox_N, tox_C, tox_H, tox_W);
    int tix_N, tix_C, tix_H, tix_W;
    tensor_shape(top_inp_tx, explicit_nhwc, tix_N, tix_C, tix_H, tix_W);
    int tih_N, tih_C, tih_H, tih_W;
    tensor_shape(top_inp_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W);
    TORCH_CHECK((toh_N == tox_N && tox_N == tix_N && tix_N == tih_N) &&
                (toh_C == tox_C && tox_C == tix_C && tix_C == tih_C) &&
                (toh_H == tox_H && tox_H == tix_H && tix_H == tih_H) &&
                (toh_W == tox_W && tox_W == tix_W && tix_W == tih_W));
    int boh_N, boh_C, boh_H, boh_W;
    tensor_shape(btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W);
    int box_N, box_C, box_H, box_W;
    tensor_shape(btm_out_tx, explicit_nhwc, box_N, box_C, box_H, box_W);
    int bix_N, bix_C, bix_H, bix_W;
    tensor_shape(btm_inp_tx, explicit_nhwc, bix_N, bix_C, bix_H, bix_W);
    int bih_N, bih_C, bih_H, bih_W;
    tensor_shape(btm_inp_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W);
    TORCH_CHECK((boh_N == box_N && box_N == bix_N && bix_N == bih_N) &&
                (boh_C == box_C && box_C == bix_C && bix_C == bih_C) &&
                (boh_H == box_H && box_H == bix_H && bix_H == bih_H) &&
                (boh_W == box_W && box_W == bix_W && bix_W == bih_W));
    TORCH_CHECK((toh_N == boh_N) && (toh_C == boh_C) && (toh_H == boh_H) && (toh_W == boh_W));
    int NC = toh_C, NH = toh_H, NW = toh_W;
    if (diagnostics) printf("NC=%d, NH=%d, NW=%d\n", NC, NH, NW);
    int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W;
    tensor_strides(top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);
    int tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W;
    tensor_strides(top_out_tx, explicit_nhwc, tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W);
    int tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W;
    tensor_strides(top_inp_tx, explicit_nhwc, tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W);
    int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W;
    tensor_strides(top_inp_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);
    int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W;
    tensor_strides(btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);
    int box_stride_N, box_stride_C, box_stride_H, box_stride_W;
    tensor_strides(btm_out_tx, explicit_nhwc, box_stride_N, box_stride_C, box_stride_H, box_stride_W);
    int bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W;
    tensor_strides(btm_inp_tx, explicit_nhwc, bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W);
    int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W;
    tensor_strides(btm_inp_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);
    // determine if nhwc
    auto is_nhwc = (toh_stride_C == 1) ? true : false;
    if (diagnostics) printf("is_nhwc = %s\n", is_nhwc ? "true" : "false");
    // figure out launch parameters
    int device;
    cudaGetDevice(&device);
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    assert(numSM > 0 && numSM <= prop.multiProcessorCount);
    auto current_stream = at::cuda::getCurrentCUDAStream();
    const int numThreads = 128;
    dim3 block(numThreads, 1, 1);
    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, top_out_halo.scalar_type(), "push_pull_halos_1d_kernel", [&]{
        if (diagnostics) printf("size(scalar_t) = %ld\n", sizeof(scalar_t));
        scalar_t* toh_p = top_out_halo.data_ptr<scalar_t>();
        scalar_t* tox_p = top_out_tx.data_ptr<scalar_t>();
        scalar_t* tix_p = top_inp_tx.data_ptr<scalar_t>();
        scalar_t* tih_p = top_inp_halo.data_ptr<scalar_t>();
        scalar_t* boh_p = btm_out_halo.data_ptr<scalar_t>();
        scalar_t* box_p = btm_out_tx.data_ptr<scalar_t>();
        scalar_t* bix_p = btm_inp_tx.data_ptr<scalar_t>();
        scalar_t* bih_p = btm_inp_halo.data_ptr<scalar_t>();
        if (diagnostics) printf("waypoint1\n");
        int* top_signal_p = top_signal.data_ptr<int>() + 4;
        int* btm_signal_p = btm_signal.data_ptr<int>();
        int* top_wait_p = waits.data_ptr<int>();
        int* btm_wait_p = waits.data_ptr<int>() + 4;
        if (diagnostics) printf("waypoint2\n");
        // do int4 vector loads if channel count permits
        int elem_size_in_bytes = toh_C * sizeof(scalar_t);
        int elem_size_in_int4 = (elem_size_in_bytes / 16);
        if (diagnostics) printf("elem_size_in_bytes = %d, elem_size_in_int4 = %d\n", elem_size_in_bytes, elem_size_in_int4);
        if (is_nhwc && elem_size_in_int4 * 16 == elem_size_in_bytes) {
            // can do int4 transfers
            int divisor = toh_C / elem_size_in_int4;
            if (diagnostics) printf("CAN DO INT4 :: divisor = %d\n", divisor);
            toh_stride_N /= divisor;  toh_stride_H /= divisor;  toh_stride_W /= divisor;
            tox_stride_N /= divisor;  tox_stride_H /= divisor;  tox_stride_W /= divisor;
            tix_stride_N /= divisor;  tix_stride_H /= divisor;  tix_stride_W /= divisor;
            tih_stride_N /= divisor;  tih_stride_H /= divisor;  tih_stride_W /= divisor;
            boh_stride_N /= divisor;  boh_stride_H /= divisor;  boh_stride_W /= divisor;
            box_stride_N /= divisor;  box_stride_H /= divisor;  box_stride_W /= divisor;
            bix_stride_N /= divisor;  bix_stride_H /= divisor;  bix_stride_W /= divisor;
            bih_stride_N /= divisor;  bih_stride_H /= divisor;  bih_stride_W /= divisor;
            NC /= divisor;
            if (diagnostics) {
                printf("divisor=%d\n", divisor);
                printf("toh_stride :: N=%d, C=%d, H=%d, W=%d\n", toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);
                printf("tox_stride :: N=%d, C=%d, H=%d, W=%d\n", tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W);
                printf("tix_stride :: N=%d, C=%d, H=%d, W=%d\n", tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W);
                printf("tih_stride :: N=%d, C=%d, H=%d, W=%d\n", tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);
                printf("boh_stride :: N=%d, C=%d, H=%d, W=%d\n", boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);
                printf("box_stride :: N=%d, C=%d, H=%d, W=%d\n", box_stride_N, box_stride_C, box_stride_H, box_stride_W);
                printf("bix_stride :: N=%d, C=%d, H=%d, W=%d\n", bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W);
                printf("bih_stride :: N=%d, C=%d, H=%d, W=%d\n", bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);
                printf("NC=%d, NH=%d, NW=%d\n", NC, NH, NW);
            }
            void* kernelArgs[] = {
                (int4**)&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
                (int4**)&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
                (int4**)&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
                (int4**)&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
                (int4**)&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
                (int4**)&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
                (int4**)&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
                (int4**)&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
                &NC, &NH, &NW,
                &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
            };
            int numBlocksPerSm;
            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4, true>, numThreads, 0);
            dim3 grid(numSM * numBlocksPerSm, 1, 1);
            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4, true>, grid, block, kernelArgs, 0, current_stream);
        } else {
            // cannot do int4 transfers
            if (diagnostics) printf("CAN NOT DO INT4\n");
            void* kernelArgs[] = {
                &toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
                &tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
                &tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
                &tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
                &boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
                &box_p, &box_stride_C, &box_stride_H, &box_stride_W,
                &bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
                &bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
                &NC, &NH, &NW,
                &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
            };
            int numBlocksPerSm;
            if (is_nhwc) {
                cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, true>, numThreads, 0);
                dim3 grid(numSM * numBlocksPerSm, 1, 1);
                cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, true>, grid, block, kernelArgs, 0, current_stream);
            } else {
                cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t, false>, numThreads, 0);
                dim3 grid(numSM * numBlocksPerSm, 1, 1);
                cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t, false>, grid, block, kernelArgs, 0, current_stream);
            }
        }
    });
}

}}}
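The kernel above couples a rank with its top and bottom neighbors through 4-int flag words polled with volatile loads; the transfer buffers and flags must live in IPC-mapped peer pools. The sketch below only illustrates the argument roles spelled out in the signature comments; which peer's blob view is passed where, and the buffer preparation, are assumptions (apex's own Python wrapper, not part of this diff, is the authoritative reference).

# Hedged sketch of one halo push/pull step built on push_pull_halos_1d.
import peer_memory_cuda as pm   # assumed name of the built extension

def halo_push_pull(top, btm, waits, num_sm=2, nhwc=False, diagnostics=False):
    # top/btm are dicts of CUDA tensors prepared elsewhere (assumed layout):
    #   out_halo / inp_halo live in this rank's ordinary device memory,
    #   out_tx is a blob view into this rank's own peer pool,
    #   inp_tx is a blob view into the neighbor's peer pool,
    #   signal is an int32 peer view the neighbor polls (at least 8 values).
    # waits is this rank's own int32 flag area: [0:4] top flags, [4:8] bottom flags.
    pm.push_pull_halos_1d(
        diagnostics, nhwc, num_sm,
        top["out_halo"], top["out_tx"], top["inp_tx"], top["inp_halo"],
        btm["out_halo"], btm["out_tx"], btm["inp_tx"], btm["inp_halo"],
        top["signal"], btm["signal"], waits)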
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh (new file, mode 100644)

/**
 * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once
#include <torch/extension.h>
#ifndef _peer_memory_h_
#define _peer_memory_h_

namespace apex { namespace contrib { namespace peer_memory {

int64_t allocate_raw(int64_t size);
void free_raw(int64_t raw);
void zero(int64_t raw, int64_t size);
at::Tensor get_raw_ipc_address(int64_t raw);
std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw);
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last);
at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last);
at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last);
void push_pull_halos_1d(
        bool diagnostics, bool explicit_nhwc, int numSM,   // number of SMs to use
        at::Tensor top_out_halo,   // top output halo in sender device memory
        at::Tensor top_out_tx,     // top output transfer buffer in sender peer pool memory
        at::Tensor top_inp_tx,     // top input transfer buffer in top neighbor peer pool memory
        at::Tensor top_inp_halo,   // top input halo in receiver device memory
        at::Tensor btm_out_halo,   // btm output halo in sender device memory
        at::Tensor btm_out_tx,     // btm output transfer buffer in sender peer pool memory
        at::Tensor btm_inp_tx,     // btm input transfer buffer in btm neighbor peer pool memory
        at::Tensor btm_inp_halo,   // btm input halo in receiver device memory
        at::Tensor top_signal,     // top input signal in receiver device memory
        at::Tensor btm_signal,     // btm input signal in receiver device memory
        at::Tensor waits           // top and btm signals for this rank
        );

}}}
#endif
apex/contrib/csrc/transducer/transducer_joint_kernel.cu

@@ -5,7 +5,7 @@
#include <torch/extension.h>
#include <ATen/AccumulateType.h>
#if !defined(NEW_GENERATOR_PATH)
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
apex/contrib/fmha/fmha.py

@@ -32,16 +32,18 @@ import fmhalib as mha
class FMHAFun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):
    def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors):
        batch_size = cu_seqlens.numel() - 1
        if batch_size < 4:
            context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
            max_s = 512
            context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, True, zero_tensors, None)
        else:
            context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
            context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, False, zero_tensors, None)
        ctx.save_for_backward(qkv, S_dmask)
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        ctx.zero_tensors = zero_tensors
        return context

    @staticmethod
@@ -49,11 +51,11 @@ class FMHAFun(torch.autograd.Function):
        qkv, S_dmask = ctx.saved_tensors
        batch_size = ctx.cu_seqlens.numel() - 1
        if batch_size < 4:
            dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
            dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
        else:
            dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
            dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
        return dqkv, None, None, None, None, None, None
        return dqkv, None, None, None, None, None

class FMHA(torch.nn.Module):
@@ -67,8 +69,8 @@ class FMHA(torch.nn.Module):
        self.d = self.hidden_size // self.h
        assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads"

    def forward(self, qkv, cu_seqlens, max_s, is_training=True):
    def forward(self, qkv, cu_seqlens, max_s, is_training=True, zero_tensors=False):
        ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training)
        ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training, zero_tensors)
        return ctx.view(-1, self.hidden_size)
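The hunks above thread a new zero_tensors flag through FMHAFun and FMHA. A hedged sketch of calling the updated autograd function directly follows; the shapes, fp16 dtype, and cu_seqlens construction describe the usual packed-QKV layout and are assumptions, not taken from this diff.

# Illustrative call of the updated FMHAFun signature.
import torch
from apex.contrib.fmha.fmha import FMHAFun   # module path as in this diff

batch, max_s, heads, head_dim = 8, 512, 16, 64
hidden = heads * head_dim
seqlens = torch.full((batch,), max_s, dtype=torch.int32, device="cuda")
# cu_seqlens holds batch + 1 prefix sums of the per-sequence lengths.
cu_seqlens = torch.zeros(batch + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(seqlens, 0, dtype=torch.int32)

qkv = torch.randn(int(cu_seqlens[-1]), 3 * hidden, dtype=torch.float16, device="cuda")
out = FMHAFun.apply(
    qkv.view(-1, 3, heads, head_dim),  # packed QKV, one row per token
    cu_seqlens,                        # int32 prefix sums, batch_size = numel() - 1
    0.1,                               # p_dropout
    max_s,                             # maximum sequence length in the batch
    True,                              # is_training
    False,                             # zero_tensors (the argument added by this commit)
)

FMHA.forward keeps the old call sites working by defaulting zero_tensors to False.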
apex/contrib/focal_loss/__init__.py (new file, mode 100644)

try:
    import torch
    import focal_loss_cuda
    from .focal_loss import focal_loss
    del torch
    del focal_loss_cuda
    del focal_loss
except ImportError as err:
    print("apex was installed without --focal_loss flag, apex.contrib.focal_loss is not available")
apex/contrib/focal_loss/focal_loss.py (new file, mode 100644)

import torch

import focal_loss_cuda


class FocalLoss(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        cls_output,
        cls_targets_at_level,
        num_positives_sum,
        num_real_classes,
        alpha,
        gamma,
        label_smoothing=0.0,
    ):
        loss, partial_grad = focal_loss_cuda.forward(
            cls_output,
            cls_targets_at_level,
            num_positives_sum,
            num_real_classes,
            alpha,
            gamma,
            label_smoothing,
        )

        ctx.save_for_backward(partial_grad, num_positives_sum)
        return loss

    @staticmethod
    def backward(ctx, grad_loss):
        partial_grad, num_positives_sum = ctx.saved_tensors

        # The backward kernel is actually in-place to save memory space,
        # partial_grad and grad_input are the same tensor.
        grad_input = focal_loss_cuda.backward(grad_loss, partial_grad, num_positives_sum)

        return grad_input, None, None, None, None, None, None


def focal_loss(
    cls_output: torch.Tensor,
    cls_targets_at_level: torch.Tensor,
    num_positive_sum: torch.Tensor,
    num_real_classes: int,
    alpha: float,
    gamma: float,
    label_smoothing: float = 0.0,
) -> torch.Tensor:
    """Fused focal loss function."""
    return FocalLoss.apply(
        cls_output,
        cls_targets_at_level,
        num_positive_sum,
        num_real_classes,
        alpha,
        gamma,
        label_smoothing,
    )
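A hedged example of calling the fused wrapper defined above. The anchor count, class count, target encoding, and dtype expectations of the underlying focal_loss_cuda kernel are illustrative assumptions; note the import goes through the submodule because the package __init__ above deletes the re-exported name.

# Illustrative call of the fused focal loss.
import torch
from apex.contrib.focal_loss.focal_loss import focal_loss

num_anchors, num_classes = 1000, 80
cls_output = torch.randn(num_anchors, num_classes, device="cuda",
                         dtype=torch.float16, requires_grad=True)
# Hypothetical per-anchor class targets; negative values mark ignored anchors here.
cls_targets = torch.randint(-1, num_classes, (num_anchors,), device="cuda", dtype=torch.int64)
num_positives = (cls_targets >= 0).sum().to(torch.float32)

loss = focal_loss(cls_output, cls_targets, num_positives, num_classes,
                  alpha=0.25, gamma=2.0, label_smoothing=0.0)
loss.backward()   # gradient comes from the in-place focal_loss_cuda.backward kernel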
(End of page 1 of 12 of this commit's diff; the remaining changed files are not shown here.)