OpenDAS / apex / Commits / 96850dfa

Unverified commit 96850dfa, authored Aug 15, 2022 by Jithun Nair, committed by GitHub on Aug 15, 2022.

Merge pull request #80 from ROCmSoftwarePlatform/IFU-master-2022-07-29

IFU-master-2022-07-29

Parents: 87fc4125, cc5f83b5
Changes: 235

Showing 20 changed files (of 235 in this commit) with 1094 additions and 41 deletions (+1094, -41).
apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu                           +6    -0
apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu                                +97   -1
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu          +1    -1
apex/contrib/csrc/multihead_attn/dropout.cuh                                      +1    -1
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu           +1    -2
apex/contrib/csrc/multihead_attn/layer_norm.cuh                                   +5    -5
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu   +7    -13
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu                 +1    -2
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu             +1    -2
apex/contrib/csrc/multihead_attn/softmax.cuh                                      +7    -5
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp                                           +25   -0
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu                                       +211  -0
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh                                      +45   -0
apex/contrib/csrc/peer_memory/peer_memory.cpp                                     +29   -0
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu                                 +529  -0
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh                                +48   -0
apex/contrib/csrc/transducer/transducer_joint_kernel.cu                           +1    -1
apex/contrib/fmha/fmha.py                                                         +10   -8
apex/contrib/focal_loss/__init__.py                                               +9    -0
apex/contrib/focal_loss/focal_loss.py                                             +60   -0
apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu  (view file @ 96850dfa)

@@ -166,6 +166,12 @@ REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
...
apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu  (view file @ 96850dfa)

@@ -67,9 +67,105 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
  }
}
REGISTER_FWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(  768, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(  768, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(  768, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(  768, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
...
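For context, the REGISTER_FWD_LAUNCHER / REGISTER_BWD_LAUNCHER entries above register tuned kernel configurations (hidden size, weight/input/output/compute dtypes, and tiling parameters) that the fast layer-norm extension selects from at run time; this merge notably adds bf16 configurations for the larger hidden sizes. A minimal sketch of how one of these configurations is exercised from Python, assuming apex was built with the --fast_layer_norm extension and exposes FastLayerNorm from apex.contrib.layer_norm as upstream apex does:

    # Hedged sketch: exercises one of the registered (hidden_size, dtype) configs.
    # Assumes apex was built with --fast_layer_norm; names follow upstream apex.
    import torch
    from apex.contrib.layer_norm import FastLayerNorm

    hidden_size = 12288                      # one of the sizes registered above
    ln = FastLayerNorm(hidden_size).cuda().to(torch.bfloat16)

    x = torch.randn(8, 512, hidden_size, device="cuda",
                    dtype=torch.bfloat16, requires_grad=True)
    y = ln(x)                                # picks a REGISTER_FWD_LAUNCHER config
    y.sum().backward()                       # picks a REGISTER_BWD_LAUNCHER config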
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu  (view file @ 96850dfa)

@@ -110,4 +110,4 @@ torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
}
} // namespace additive_mask_softmax_dropout
} // namespace fused_softmax
} // namespace multihead_attn
\ No newline at end of file
apex/contrib/csrc/multihead_attn/dropout.cuh  (view file @ 96850dfa)

 #pragma once
 #include <ATen/ATen.h>
-#if !defined(NEW_GENERATOR_PATH)
+#ifdef OLD_GENERATOR_PATH
 #include <ATen/CUDAGeneratorImpl.h>
 #else
 #include <ATen/cuda/CUDAGeneratorImpl.h>
...
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu  (view file @ 96850dfa)

@@ -687,5 +687,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace encdec_norm_add
} // end namespace multihead_attn
\ No newline at end of file
apex/contrib/csrc/multihead_attn/layer_norm.cuh  (view file @ 96850dfa)

@@ -67,7 +67,7 @@ __device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1,
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x + (1 << l)) & 31;
      U muB = WARP_SHFL(mu, srcLaneB, 32);
      U countB = WARP_SHFL(count, srcLaneB, 32);
      U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
...
@@ -108,7 +108,7 @@ __device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1,
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0, 32);
      sigma2 = WARP_SHFL(sigma2 / U(n2), 0, 32);
    }
  }
}
...
@@ -158,7 +158,7 @@ __device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals,
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x + (1 << l)) & 31;
      float muB = WARP_SHFL(mu, srcLaneB, 32);
      float countB = WARP_SHFL(count, srcLaneB, 32);
      float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
...
@@ -199,7 +199,7 @@ __device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals,
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0, 32);
      sigma2 = WARP_SHFL(sigma2 / float(n2), 0, 32);
    }
  }
}
...
@@ -521,7 +521,7 @@ cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid,
    }
  }
  // intra-warp reductions
  for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
    sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
    sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
  }
...
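The hunks above shuffle partial (mean, variance accumulator, count) statistics between lanes with WARP_SHFL and fold them together; the combination rule behind such an intra-warp reduction is the standard pairwise Welford/Chan merge. A small illustrative sketch of that merge (names are mine, not from the header):

    # Hedged sketch of the pairwise Welford merge performed lane by lane
    # by the WARP_SHFL-based intra-warp reduction above.
    def welford_merge(count_a, mu_a, m2_a, count_b, mu_b, m2_b):
        """Merge two partial (count, mean, sum-of-squared-deviations) statistics."""
        count = count_a + count_b
        if count == 0.0:
            return 0.0, 0.0, 0.0
        delta = mu_b - mu_a
        mu = mu_a + delta * (count_b / count)
        m2 = m2_a + m2_b + delta * delta * count_a * count_b / count
        return count, mu, m2

    # Example: merging statistics of [1, 2] and [3, 4]
    c, m, m2 = welford_merge(2, 1.5, 0.5, 2, 3.5, 0.5)
    print(c, m, m2 / c)   # 4, 2.5, variance 1.25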
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu  (view file @ 96850dfa)

@@ -19,19 +19,13 @@ namespace multihead_attn {
 namespace self_bias_additive_mask {
 namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
-    bool use_time_mask,
-    bool is_training,
-    int heads,
-    torch::Tensor const &inputs,
-    torch::Tensor const &input_weights,
-    torch::Tensor const &output_weights,
-    torch::Tensor const &input_biases,
-    torch::Tensor const &output_biases,
-    const half *pad_mask,
-    float dropout_prob) {
+    bool use_time_mask, bool is_training,
+    int heads, torch::Tensor const &inputs,
+    torch::Tensor const &input_weights,
+    torch::Tensor const &output_weights,
+    torch::Tensor const &input_biases,
+    torch::Tensor const &output_biases,
+    const half *pad_mask, float dropout_prob) {
  const int embed_dim = inputs.size(2);
  const int sequences = inputs.size(1);
  const int q_seq_len = inputs.size(0);
...
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu  (view file @ 96850dfa)

@@ -501,5 +501,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
\ No newline at end of file
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu  (view file @ 96850dfa)

@@ -577,5 +577,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
\ No newline at end of file
apex/contrib/csrc/multihead_attn/softmax.cuh  (view file @ 96850dfa)

@@ -3,7 +3,7 @@
 #include <ATen/cuda/CUDAGraphsUtils.cuh>
 #include <curand_kernel.h>
-#if !defined(NEW_GENERATOR_PATH)
+#ifdef OLD_GENERATOR_PATH
 #include <ATen/CUDAGeneratorImpl.h>
 #else
 #include <ATen/cuda/CUDAGeneratorImpl.h>
...
@@ -1593,11 +1593,13 @@ int log2_ceil_native(int value) {
 }
 template <typename T>
-__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) {
-#if CUDA_VERSION >= 9000 && !defined(__HIP_PLATFORM_HCC__)
-  return __shfl_xor_sync(mask, value, laneMask, width);
+__device__ __forceinline__ T
+WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
+                     unsigned int mask = 0xffffffff) {
+#if CUDA_VERSION >= 9000
+  return __shfl_xor_sync(mask, value, laneMask, width);
 #else
  return __shfl_xor(value, laneMask, width);
 #endif
 }
...
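WARP_SHFL_XOR_NATIVE wraps __shfl_xor_sync / __shfl_xor so each lane reads the value held by the lane whose index differs by laneMask; applied repeatedly with masks 16, 8, 4, 2, 1 this gives a butterfly reduction in which every lane ends up with the full sum. A hedged Python simulation of that access pattern (pure illustration, not part of the file):

    # Illustration of the butterfly (XOR) reduction enabled by WARP_SHFL_XOR_NATIVE;
    # the 32 warp lanes are simulated with a plain list.
    def warp_xor_reduce(values):
        """After all steps, every simulated lane holds the total sum."""
        lanes = list(values)
        assert len(lanes) == 32
        for mask in (16, 8, 4, 2, 1):
            # each lane i adds the value of lane i ^ mask (what __shfl_xor_sync returns)
            lanes = [lanes[i] + lanes[i ^ mask] for i in range(32)]
        return lanes

    result = warp_xor_reduce(range(32))
    assert all(v == sum(range(32)) for v in result)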
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp  (new file, 0 → 100644, view file @ 96850dfa)
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nccl_p2p_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id");
    m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm");
    m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace");
    m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange");
    m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay");
}
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu  (new file, 0 → 100644, view file @ 96850dfa)
#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <ctime>
#include <cassert>
#include "nccl.h"
/*
 * This file implements a crude but effective mechanism for copying data between tensors owned by different ranks
* on the same machine using cudaMemcpyAsync peer-to-peer transfers.
*/
namespace {

__global__ void AddDelay_kernel(const int delay, int* counter)
{
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.
        int new_counter = 0;
        double elapsed = 0;
        clock_t start = clock();
        do {
            clock_t now = clock();
            elapsed = (double)(now - start) * 1e9 / CLOCKS_PER_SEC;
            ++new_counter;
        } while (elapsed < (double)delay);
        *counter = new_counter;
    }
}

class NcclCommWrapper
{
    private:

    ncclComm_t comm;
    int rank, world_size;

    ncclDataType_t get_nccl_type(at::Tensor input)
    {
        switch (input.scalar_type())
        {
            case at::ScalarType::Half:      return ncclFloat16;
            case at::ScalarType::Float:     return ncclFloat32;
            case at::ScalarType::Double:    return ncclFloat64;
            case at::ScalarType::Byte:      return ncclUint8;
            case at::ScalarType::Char:      return ncclInt8;
            case at::ScalarType::Int:       return ncclInt32;
            case at::ScalarType::Long:      return ncclInt64;
            case at::ScalarType::BFloat16:  return ncclBfloat16;
            default:                        assert(false);
        }
    }

    public:

    NcclCommWrapper()
    {
        memset(&comm, 0, sizeof(ncclComm_t));
        rank = 0;
        world_size = 0;
    }

    NcclCommWrapper(ncclUniqueId id, int my_rank, int num_ranks)
    {
        ncclCommInitRank(&comm, num_ranks, id, my_rank);
        rank = my_rank;
        world_size = num_ranks;
    }

    ~NcclCommWrapper()
    {
        printf("ncclCommDestroy()\n");
        ncclCommDestroy(comm);
    }

    void left_right_halo_exchange_inplace(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo)
    {
        auto stream = at::cuda::getCurrentCUDAStream();
        ncclGroupStart();
        ncclDataType_t ncclType = get_nccl_type(left_output_halo);
        bool left_zero = (left_rank < 0);
        bool right_zero = (right_rank < 0);
        size_t left_n = torch::numel(left_output_halo);
        size_t right_n = torch::numel(right_output_halo);
        assert(left_n > 0 && left_n == right_n);
        if (left_zero) {
            left_input_halo.zero_();
        } else {
            AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, left_output_halo.scalar_type(), "left_halo_exch", [&]() {
                // send left (to my_rank - 1)
                ncclSend(left_output_halo.data_ptr<scalar_t>(), left_n, ncclType, left_rank, comm, stream);
                // receive left (from my_rank - 1)
                ncclRecv(left_input_halo.data_ptr<scalar_t>(), right_n, ncclType, left_rank, comm, stream);
            });
        }
        if (right_zero) {
            right_input_halo.zero_();
        } else {
            AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, right_output_halo.scalar_type(), "right_halo_exch", [&]() {
                // send right (to my_rank + 1)
                ncclSend(right_output_halo.data_ptr<scalar_t>(), right_n, ncclType, right_rank, comm, stream);
                // receive right (from my_rank + 1)
                ncclRecv(right_input_halo.data_ptr<scalar_t>(), left_n, ncclType, right_rank, comm, stream);
            });
        }
        ncclGroupEnd();
    }

    std::vector<at::Tensor> left_right_halo_exchange(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo)
    {
        // after halo exchange:
        // left_output_halo of rank+1 ends up in right_input_halo of rank
        // right_output_halo of rank-1 ends up in left_input_halo of rank
        auto right_input_halo = torch::empty_like(left_output_halo);
        auto left_input_halo = torch::empty_like(right_output_halo);
        left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo);
        return {left_input_halo, right_input_halo};
    }
};

class ManagedObjects
{
    public:

    ManagedObjects() {}

    ~ManagedObjects()
    {
        for (auto it = _nccl_comms.begin(); it != _nccl_comms.end(); ++it) {
            delete *it;
        }
    }

    int add_comm(NcclCommWrapper* comm)
    {
        int handle = _nccl_comms.size();
        _nccl_comms.push_back(comm);
        return handle;
    }

    NcclCommWrapper& get_comm(int handle)
    {
        assert(handle >= 0 && handle < _nccl_comms.size());
        return *_nccl_comms[handle];
    }

    private:

    std::vector<NcclCommWrapper*> _nccl_comms;
};
class ManagedObjects mo;

}  // end anonymous namespace

namespace apex { namespace contrib { namespace nccl_p2p {

at::Tensor get_unique_nccl_id(int n)
{
    ncclUniqueId id;
    ncclGetUniqueId(&id);
    auto id_tensor = torch::empty({n, (int)sizeof(ncclUniqueId)}, torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));
    auto id_ptr = id_tensor.data_ptr<uint8_t>();
    size_t offset = 0;
    for (int i = 0; i < n; ++i) {
        ncclUniqueId id;
        ncclGetUniqueId(&id);
        memcpy(id_ptr + offset, &id, sizeof(ncclUniqueId));
        offset += sizeof(ncclUniqueId);
    }
    return id_tensor;
}

int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks)
{
    ncclUniqueId id;
    auto unique_nccl_id_ptr = unique_nccl_id.data_ptr<uint8_t>();
    memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId));
    NcclCommWrapper* comm = new NcclCommWrapper(id, my_rank, num_ranks);
    int handle = mo.add_comm(comm);
    comm = 0L;
    return handle;
}

void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo)
{
    class NcclCommWrapper& communicator = mo.get_comm(handle);
    return communicator.left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo);
}

std::vector<at::Tensor> left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo)
{
    class NcclCommWrapper& communicator = mo.get_comm(handle);
    return communicator.left_right_halo_exchange(left_rank, right_rank, left_output_halo, right_output_halo);
}

void add_delay(int delay)
{
    auto stream = at::cuda::getCurrentCUDAStream();
    auto t = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
    AddDelay_kernel<<<1, 1, 0, stream>>>(delay, t.data_ptr<int>());
}

}}}
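The functions above are exposed to Python through the pybind11 module in nccl_p2p.cpp. A hedged sketch of how a rank could drive them; the extension module name and the torch.distributed bootstrap of the unique id are assumptions (they depend on the build and launch setup), while the function names, argument orders, and the "negative rank means no neighbor, zero-fill" behavior come from the code above:

    # Hedged sketch: bootstrapping the halo-exchange communicator.
    # "nccl_p2p_cuda" as module name and the broadcast step are assumptions.
    import torch
    import torch.distributed as dist
    import nccl_p2p_cuda  # assumed extension name

    dist.init_process_group(backend="nccl")
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # rank 0's ncclUniqueId (a uint8 CPU tensor) is shared with everyone
    id_tensor = nccl_p2p_cuda.get_unique_nccl_id(1).cuda()
    dist.broadcast(id_tensor, src=0)
    handle = nccl_p2p_cuda.init_nccl_comm(id_tensor.cpu()[0], rank, world)

    # 1D decomposition neighbors; -1 means "no neighbor", which zero-fills the input halo
    left = rank - 1 if rank > 0 else -1
    right = rank + 1 if rank < world - 1 else -1

    left_out = torch.randn(1, 3, 128, device="cuda", dtype=torch.half)
    right_out = torch.randn(1, 3, 128, device="cuda", dtype=torch.half)
    left_in, right_in = nccl_p2p_cuda.left_right_halo_exchange(
        handle, left, right, left_out, right_out)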
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh  (new file, 0 → 100644, view file @ 96850dfa)
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>
#ifndef _nccl_p2p_h_
#define _nccl_p2p_h_
namespace apex { namespace contrib { namespace nccl_p2p {

at::Tensor get_unique_nccl_id(int n);
int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks);
void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo);
std::vector<at::Tensor> left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo);
void add_delay(int delay);

}}}
#endif
apex/contrib/csrc/peer_memory/peer_memory.cpp  (new file, 0 → 100644, view file @ 96850dfa)
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "peer_memory_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw");
    m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw");
    m.def("zero", &apex::contrib::peer_memory::zero, "zero");
    m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address");
    m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers");
    m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half");
    m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float");
    m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int");
    m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d");
}
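These bindings expose the raw peer-memory pool to Python: a rank allocates device memory, publishes its CUDA IPC handle, maps its peers' pools, and then views slices of any pool as ordinary tensors. A hedged sketch of that flow; the extension module name and the use of torch.distributed.all_gather for the handle exchange are assumptions, the function names and argument orders come from the bindings above:

    # Hedged sketch: carving a peer-visible buffer out of raw CUDA memory.
    import torch
    import torch.distributed as dist
    import peer_memory_cuda  # assumed extension name

    dist.init_process_group(backend="nccl")
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank)

    pool_size = 64 * 1024 * 1024
    raw = peer_memory_cuda.allocate_raw(pool_size)       # cudaMalloc + cudaMemset(0)

    # exchange cudaIpcMemHandle_t blobs so every rank can map its peers' pools
    ipc = peer_memory_cuda.get_raw_ipc_address(raw).cuda()
    gathered = [torch.empty_like(ipc) for _ in range(world)]
    dist.all_gather(gathered, ipc)
    peers = peer_memory_cuda.get_raw_peers(torch.stack(gathered).cpu(), rank, raw)

    # view the local pool (or a peer's pool) as an NCHW float tensor
    local_view = peer_memory_cuda.blob_view_float(raw, [1, 64, 32, 32], False)
    peer_view = peer_memory_cuda.blob_view_float(peers[(rank + 1) % world], [1, 64, 32, 32], False)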
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu  (new file, 0 → 100644, view file @ 96850dfa)
#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <cassert>
#include <cuda_runtime_api.h>
#include <cooperative_groups.h>
#include "nccl.h"
namespace cg = cooperative_groups;
#define CUDACHECK(cmd) do { \
cudaError_t err = cmd; \
if( err != cudaSuccess ) { \
char hostname[1024]; \
gethostname(hostname, 1024); \
printf("%s: CUDA failure %s:%d '%s'\n", \
hostname, \
__FILE__,__LINE__,cudaGetErrorString(err)); \
} \
} while(0)
namespace {
/* Basic deleter function for from_blob function.
void deleter(void* ptr)
{
printf("deleter(ptr=%p)\n",ptr);
cudaFree(ptr);
}
*/
template<class T>
at::Tensor blob_view(T* raw_ptr, std::vector<int64_t> shape, const at::TensorOptions& options, bool channels_last)
{
    size_t size = 1;
    std::vector<int64_t> strides(shape.size());
    if (channels_last) {
        assert(shape.size() == 4);
        strides[0] = shape[1]*shape[2]*shape[3];
        strides[1] = 1;
        strides[2] = shape[1]*shape[3];
        strides[3] = shape[1];
    } else {
        int idx = strides.size();
        for (auto it = shape.rbegin(); it != shape.rend(); ++it) {
            strides[--idx] = size;
            size *= *it;
        }
    }
    size *= sizeof(T);
    // TODO: Implement dynamic reuse of pooled peer memory.
    // We provide no deleter function because all peer memory allocations are static in this implementation.
    return torch::from_blob((void*)raw_ptr, shape, strides, 0L, options);
}

void tensor_shape(at::Tensor t, bool explicit_nhwc, int& N, int& C, int& H, int& W)
{
    if (t.dim() == 3) {
        N = 1;
        if (explicit_nhwc) {
            C = t.size(2);  H = t.size(0);  W = t.size(1);
        } else {
            C = t.size(0);  H = t.size(1);  W = t.size(2);
        }
    } else if (t.dim() == 4) {
        if (explicit_nhwc) {
            N = t.size(0);  C = t.size(3);  H = t.size(1);  W = t.size(2);
        } else {
            N = t.size(0);  C = t.size(1);  H = t.size(2);  W = t.size(3);
        }
    } else {
        printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n", __FILE__, __LINE__, t.dim());
        assert(t.dim() == 3 || t.dim() == 4);
    }
}

void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride_C, int& stride_H, int& stride_W)
{
    if (t.dim() == 3) {
        if (explicit_nhwc) {
            stride_C = t.stride(2);  stride_H = t.stride(0);  stride_W = t.stride(1);
        } else {
            stride_C = t.stride(0);  stride_H = t.stride(1);  stride_W = t.stride(2);
        }
        stride_N = t.size(0)*t.size(1)*t.size(2);
    } else if (t.dim() == 4) {
        if (explicit_nhwc) {
            stride_N = t.stride(0);  stride_C = t.stride(3);  stride_H = t.stride(1);  stride_W = t.stride(2);
        } else {
            stride_N = t.stride(0);  stride_C = t.stride(1);  stride_H = t.stride(2);  stride_W = t.stride(3);
        }
    } else {
        printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n", __FILE__, __LINE__, t.dim());
        assert(t.dim() == 3 || t.dim() == 4);
    }
}

template<class T, bool is_HWC>
__device__ void strided_copy_kernel(
        T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W,
        const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W,
        const int NC, const int NH, const int NW)
{
    size_t tot_num_threads = gridDim.x * blockDim.x;
    size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
    const size_t count = NC*NH*NW;
    for (size_t i = thread_id; i < count; i += tot_num_threads) {
        size_t c, h, w;
        if (is_HWC) {
            c = i % NC;
            w = i / NC;
            h = w / NW;
            w = w % NW;
        } else {
            w = i % NW;
            h = i / NW;
            c = h / NH;
            h = h % NH;
        }
        size_t dst_off = c*dst_stride_C + h*dst_stride_H + w*dst_stride_W;
        size_t src_off = c*src_stride_C + h*src_stride_H + w*src_stride_W;
        dst[dst_off] = src[src_off];
    }
}

__device__ void checked_signal(volatile int* signal1_flag, volatile int* signal2_flag, const int v1, const int v2, const int v3, const int v4)
{
    cg::this_grid().sync();
    bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
    if (is_main_thread) {
        // flush all writes to global memory
        __threadfence_system();
        // wait for top or bottom neighbor to clear signal
        register int r1, r2, r3, r4;
        bool top_zeroed = false, btm_zeroed = false, top_done = false, btm_done = false;
        do {
            do {
                if (!top_zeroed) {
                    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
                }
                if (!btm_zeroed) {
                    asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
                    if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
                }
            } while ((top_zeroed == top_done) && (btm_zeroed == btm_done));
            if (!top_done && top_zeroed) {
                // signal to top neighbor my output is ready
                asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
                top_done = true;
            }
            if (!btm_done && btm_zeroed) {
                // signal to bottom neighbor my output is ready
                asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
                btm_done = true;
            }
        } while (!top_done || !btm_done);
    }
}

__device__ void wait_for(volatile int* wait_flag, const int v1, const int v2, const int v3, const int v4)
{
    bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
    if (is_main_thread) {
        register int r1, r2, r3, r4;
        // wait for senders to signal their output is ready
        do {
            asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
        } while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
    }
    cg::this_grid().sync();  // all threads wait for main
}

__device__ void clear_flag(volatile int* wait_flag)
{
    cg::this_grid().sync();  // wait for all threads in kernel to finish
    bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
    if (is_main_thread) {
        register int r1, r2, r3, r4;
        r1 = 0;  r2 = 0;  r3 = 0;  r4 = 0;
        asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
    }
}

template<class T, bool is_HWC>
#if __CUDA_ARCH__ >= 700
__launch_bounds__(128, 16)
#endif
__global__ void push_pull_halos_1d_kernel(
        // top halo
        const T* toh, int toh_stride_C, int toh_stride_H, int toh_stride_W,  // top output halo
        T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W,        // top output tx buffer
        T* tix, int tix_stride_C, int tix_stride_H, int tix_stride_W,        // top input tx buffer
        T* tih, int tih_stride_C, int tih_stride_H, int tih_stride_W,        // top input halo
        // btm halo
        const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W,  // btm output halo
        T* box, int box_stride_C, int box_stride_H, int box_stride_W,        // btm output tx buffer
        T* bix, int bix_stride_C, int bix_stride_H, int bix_stride_W,        // btm input tx buffer
        T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W,        // btm input halo
        // dimensions
        int NC, int NH, int NW,
        // signals
        int* signal1_flag, int* signal2_flag, int* wait1_flag, int* wait2_flag)
{
    // push top output halo to transfer buffer
    strided_copy_kernel<T,is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
    // push btm output halo to transfer buffer
    strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
    // signal to top and btm neighbors that output halos are ready to be read
    // the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
    checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
    // pull top halo from transfer buffer in peer memory to input
    wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
    strided_copy_kernel<T,is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
    clear_flag(wait1_flag);
    // pull btm halo from transfer buffer in peer memory to input
    wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
    strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
    clear_flag(wait2_flag);
}

__global__ void delay_kernel(int delay_nanoseconds, int* counter)
{
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.
        int new_counter = 0;
        double elapsed = 0;
        clock_t start = clock();
        do {
            clock_t now = clock();
            elapsed = (double)(now - start) * 1e9 / CLOCKS_PER_SEC;
            ++new_counter;
        } while (elapsed < (double)delay_nanoseconds);
        *counter = new_counter;
    }
}

}

namespace apex { namespace contrib { namespace peer_memory {

int64_t allocate_raw(int64_t size)
{
    float* ptr = 0L;
    cudaMalloc(&ptr, size);
    cudaMemset(ptr, 0, size);
    return (int64_t)ptr;
}

void free_raw(int64_t raw)
{
    cudaFree((void*)raw);
}

void zero(int64_t raw, int64_t size)
{
    cudaMemset((void*)raw, 0, size);
}

at::Tensor get_raw_ipc_address(int64_t raw)
{
    cudaIpcMemHandle_t mem_handle;
    CUDACHECK( cudaIpcGetMemHandle(&mem_handle, (void*)raw) );
    const int n = sizeof(cudaIpcMemHandle_t);
    auto address_tensor = torch::empty({n}, torch::dtype(torch::kUInt8));
    auto address_tensor_p = address_tensor.data_ptr<uint8_t>();
    memcpy(address_tensor_p, (uint8_t*)&mem_handle, n);
    return address_tensor;
}

std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw)
{
    int peer_group_size = ipc_addresses.size(0);
    std::vector<int64_t> results(peer_group_size);
    for (int i = 0; i < peer_group_size; ++i) {
        if (i != peer_rank) {
            cudaIpcMemHandle_t mem_handle;
            memcpy(&mem_handle, ipc_addresses.index({i}).data_ptr<uint8_t>(), sizeof(cudaIpcMemHandle_t));
            void* p = 0L;
            CUDACHECK( cudaIpcOpenMemHandle((void**)&p, mem_handle, cudaIpcMemLazyEnablePeerAccess) );
            results[i] = (int64_t)p;
        } else {
            results[i] = (int64_t)raw;
        }
    }
    return results;
}

at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last)
{
    return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last);
}

at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last)
{
    return blob_view<float>((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last);
}

at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last)
{
    return blob_view<int>((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last);
}

void push_pull_halos_1d(
        bool diagnostics,
        bool explicit_nhwc,
        int numSM,                      // number of SMs to use
        at::Tensor top_out_halo,        // top output halo in sender device memory
        at::Tensor top_out_tx,          // top output transfer buffer in sender peer pool memory
        at::Tensor top_inp_tx,          // top input transfer buffer in top neighbor peer pool memory
        at::Tensor top_inp_halo,        // top input halo in receiver device memory
        at::Tensor btm_out_halo,        // btm output halo in sender device memory
        at::Tensor btm_out_tx,          // btm output transfer buffer in sender peer pool memory
        at::Tensor btm_inp_tx,          // btm input transfer buffer in btm neighbor peer pool memory
        at::Tensor btm_inp_halo,        // btm input halo in receiver device memory
        at::Tensor top_signal,          // top input signal in receiver device memory
        at::Tensor btm_signal,          // btm input signal in receiver device memory
        at::Tensor waits                // top and btm signals for this rank
        )
{
    // basic checks of inputs
    TORCH_CHECK(top_out_halo.is_cuda());
    TORCH_CHECK(top_out_tx.is_cuda());
    TORCH_CHECK(top_inp_tx.is_cuda());
    TORCH_CHECK(top_inp_halo.is_cuda());
    TORCH_CHECK(btm_out_halo.is_cuda());
    TORCH_CHECK(btm_out_tx.is_cuda());
    TORCH_CHECK(btm_inp_tx.is_cuda());
    TORCH_CHECK(btm_inp_halo.is_cuda());
    TORCH_CHECK(top_signal.is_cuda());
    TORCH_CHECK(btm_signal.is_cuda());
    TORCH_CHECK(waits.is_cuda());

    // shapes and strides
    int toh_N, toh_C, toh_H, toh_W;
    tensor_shape(top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W);
    int tox_N, tox_C, tox_H, tox_W;
    tensor_shape(top_out_tx, explicit_nhwc, tox_N, tox_C, tox_H, tox_W);
    int tix_N, tix_C, tix_H, tix_W;
    tensor_shape(top_inp_tx, explicit_nhwc, tix_N, tix_C, tix_H, tix_W);
    int tih_N, tih_C, tih_H, tih_W;
    tensor_shape(top_inp_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W);
    TORCH_CHECK((toh_N == tox_N && tox_N == tix_N && tix_N == tih_N) &&
                (toh_C == tox_C && tox_C == tix_C && tix_C == tih_C) &&
                (toh_H == tox_H && tox_H == tix_H && tix_H == tih_H) &&
                (toh_W == tox_W && tox_W == tix_W && tix_W == tih_W));
    int boh_N, boh_C, boh_H, boh_W;
    tensor_shape(btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W);
    int box_N, box_C, box_H, box_W;
    tensor_shape(btm_out_tx, explicit_nhwc, box_N, box_C, box_H, box_W);
    int bix_N, bix_C, bix_H, bix_W;
    tensor_shape(btm_inp_tx, explicit_nhwc, bix_N, bix_C, bix_H, bix_W);
    int bih_N, bih_C, bih_H, bih_W;
    tensor_shape(btm_inp_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W);
    TORCH_CHECK((boh_N == box_N && box_N == bix_N && bix_N == bih_N) &&
                (boh_C == box_C && box_C == bix_C && bix_C == bih_C) &&
                (boh_H == box_H && box_H == bix_H && bix_H == bih_H) &&
                (boh_W == box_W && box_W == bix_W && bix_W == bih_W));
    TORCH_CHECK((toh_N == boh_N) && (toh_C == boh_C) && (toh_H == boh_H) && (toh_W == boh_W));
    int NC = toh_C, NH = toh_H, NW = toh_W;
    if (diagnostics) printf("NC=%d, NH=%d, NW=%d\n", NC, NH, NW);
    int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W;
    tensor_strides(top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);
    int tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W;
    tensor_strides(top_out_tx, explicit_nhwc, tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W);
    int tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W;
    tensor_strides(top_inp_tx, explicit_nhwc, tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W);
    int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W;
    tensor_strides(top_inp_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);
    int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W;
    tensor_strides(btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);
    int box_stride_N, box_stride_C, box_stride_H, box_stride_W;
    tensor_strides(btm_out_tx, explicit_nhwc, box_stride_N, box_stride_C, box_stride_H, box_stride_W);
    int bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W;
    tensor_strides(btm_inp_tx, explicit_nhwc, bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W);
    int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W;
    tensor_strides(btm_inp_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);

    // determine if nhwc
    auto is_nhwc = (toh_stride_C == 1) ? true : false;
    if (diagnostics) printf("is_nhwc = %s\n", is_nhwc ? "true" : "false");

    // figure out launch parameters
    int device;
    cudaGetDevice(&device);
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    assert(numSM > 0 && numSM <= prop.multiProcessorCount);
    auto current_stream = at::cuda::getCurrentCUDAStream();
    const int numThreads = 128;
    dim3 block(numThreads, 1, 1);

    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, top_out_halo.scalar_type(), "push_pull_halos_1d_kernel", [&]{
        if (diagnostics) printf("size(scalar_t) = %ld\n", sizeof(scalar_t));
        scalar_t* toh_p = top_out_halo.data_ptr<scalar_t>();
        scalar_t* tox_p = top_out_tx.data_ptr<scalar_t>();
        scalar_t* tix_p = top_inp_tx.data_ptr<scalar_t>();
        scalar_t* tih_p = top_inp_halo.data_ptr<scalar_t>();
        scalar_t* boh_p = btm_out_halo.data_ptr<scalar_t>();
        scalar_t* box_p = btm_out_tx.data_ptr<scalar_t>();
        scalar_t* bix_p = btm_inp_tx.data_ptr<scalar_t>();
        scalar_t* bih_p = btm_inp_halo.data_ptr<scalar_t>();
        if (diagnostics) printf("waypoint1\n");
        int* top_signal_p = top_signal.data_ptr<int>() + 4;
        int* btm_signal_p = btm_signal.data_ptr<int>();
        int* top_wait_p = waits.data_ptr<int>();
        int* btm_wait_p = waits.data_ptr<int>() + 4;
        if (diagnostics) printf("waypoint2\n");

        // do int4 vector loads if channel count permits
        int elem_size_in_bytes = toh_C * sizeof(scalar_t);
        int elem_size_in_int4 = (elem_size_in_bytes / 16);
        if (diagnostics) printf("elem_size_in_bytes = %d, elem_size_in_int4 = %d\n", elem_size_in_bytes, elem_size_in_int4);
        if (is_nhwc && elem_size_in_int4 * 16 == elem_size_in_bytes) {
            // can do int4 transfers
            int divisor = toh_C / elem_size_in_int4;
            if (diagnostics) printf("CAN DO INT4 :: divisor = %d\n", divisor);
            toh_stride_N /= divisor;  toh_stride_H /= divisor;  toh_stride_W /= divisor;
            tox_stride_N /= divisor;  tox_stride_H /= divisor;  tox_stride_W /= divisor;
            tix_stride_N /= divisor;  tix_stride_H /= divisor;  tix_stride_W /= divisor;
            tih_stride_N /= divisor;  tih_stride_H /= divisor;  tih_stride_W /= divisor;
            boh_stride_N /= divisor;  boh_stride_H /= divisor;  boh_stride_W /= divisor;
            box_stride_N /= divisor;  box_stride_H /= divisor;  box_stride_W /= divisor;
            bix_stride_N /= divisor;  bix_stride_H /= divisor;  bix_stride_W /= divisor;
            bih_stride_N /= divisor;  bih_stride_H /= divisor;  bih_stride_W /= divisor;
            NC /= divisor;
            if (diagnostics) {
                printf("divisor=%d\n", divisor);
                printf("toh_stride :: N=%d, C=%d, H=%d, W=%d\n", toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);
                printf("tox_stride :: N=%d, C=%d, H=%d, W=%d\n", tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W);
                printf("tix_stride :: N=%d, C=%d, H=%d, W=%d\n", tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W);
                printf("tih_stride :: N=%d, C=%d, H=%d, W=%d\n", tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);
                printf("boh_stride :: N=%d, C=%d, H=%d, W=%d\n", boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);
                printf("box_stride :: N=%d, C=%d, H=%d, W=%d\n", box_stride_N, box_stride_C, box_stride_H, box_stride_W);
                printf("bix_stride :: N=%d, C=%d, H=%d, W=%d\n", bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W);
                printf("bih_stride :: N=%d, C=%d, H=%d, W=%d\n", bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);
                printf("NC=%d, NH=%d, NW=%d\n", NC, NH, NW);
            }
            void* kernelArgs[] = {
                (int4**)&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
                (int4**)&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
                (int4**)&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
                (int4**)&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
                (int4**)&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
                (int4**)&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
                (int4**)&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
                (int4**)&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
                &NC, &NH, &NW,
                &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
            };
            int numBlocksPerSm;
            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true>, numThreads, 0);
            dim3 grid(numSM*numBlocksPerSm, 1, 1);
            cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
        } else {
            // cannot do int4 transfers
            if (diagnostics) printf("CAN NOT DO INT4\n");
            void* kernelArgs[] = {
                &toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
                &tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
                &tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
                &tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
                &boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
                &box_p, &box_stride_C, &box_stride_H, &box_stride_W,
                &bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
                &bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
                &NC, &NH, &NW,
                &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
            };
            int numBlocksPerSm;
            if (is_nhwc) {
                cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true>, numThreads, 0);
                dim3 grid(numSM*numBlocksPerSm, 1, 1);
                cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true>, grid, block, kernelArgs, 0, current_stream);
            } else {
                cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false>, numThreads, 0);
                dim3 grid(numSM*numBlocksPerSm, 1, 1);
                cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false>, grid, block, kernelArgs, 0, current_stream);
            }
        }
    });
}

}
}
}
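The int4 fast path above requires an NHWC layout with C * sizeof(dtype) divisible by 16, and the signal/wait pointers are offset in units of four ints (top_signal is read at offset +4, while waits holds the top flag at offset 0 and the btm flag at offset +4). A heavily hedged sketch of one push_pull_halos_1d call, only to show the argument order documented in peer_memory_cuda.cuh; in a real setup the tx buffers and signal tensors live in peer-mapped pools, and matching calls must run on the neighbor ranks or the kernel will spin waiting for their signals:

    # Hedged sketch: plain CUDA tensors stand in for peer-pool buffers here.
    import torch
    import peer_memory_cuda  # assumed extension name

    N, C, H, W = 1, 64, 4, 256      # C * 2 bytes divisible by 16 enables the int4 path
    halo = lambda: torch.zeros(N, H, W, C, device="cuda", dtype=torch.half)  # explicit NHWC

    top_out_halo, top_out_tx, top_inp_tx, top_inp_halo = halo(), halo(), halo(), halo()
    btm_out_halo, btm_out_tx, btm_inp_tx, btm_inp_halo = halo(), halo(), halo(), halo()

    # each signal/wait tensor holds two int4-sized flags (8 int32 values)
    top_signal = torch.zeros(8, device="cuda", dtype=torch.int32)
    btm_signal = torch.zeros(8, device="cuda", dtype=torch.int32)
    waits      = torch.zeros(8, device="cuda", dtype=torch.int32)

    peer_memory_cuda.push_pull_halos_1d(
        True,    # diagnostics
        True,    # explicit_nhwc
        2,       # numSM
        top_out_halo, top_out_tx, top_inp_tx, top_inp_halo,
        btm_out_halo, btm_out_tx, btm_inp_tx, btm_inp_halo,
        top_signal, btm_signal, waits)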
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh  (new file, 0 → 100644, view file @ 96850dfa)
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>
#ifndef _peer_memory_h_
#define _peer_memory_h_
namespace apex { namespace contrib { namespace peer_memory {

int64_t allocate_raw(int64_t size);
void free_raw(int64_t raw);
void zero(int64_t raw, int64_t size);
at::Tensor get_raw_ipc_address(int64_t raw);
std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw);
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last);
at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last);
at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last);
void push_pull_halos_1d(
        bool diagnostics,
        bool explicit_nhwc,
        int numSM,                      // number of SMs to use
        at::Tensor top_out_halo,        // top output halo in sender device memory
        at::Tensor top_out_tx,          // top output transfer buffer in sender peer pool memory
        at::Tensor top_inp_tx,          // top input transfer buffer in top neighbor peer pool memory
        at::Tensor top_inp_halo,        // top input halo in receiver device memory
        at::Tensor btm_out_halo,        // btm output halo in sender device memory
        at::Tensor btm_out_tx,          // btm output transfer buffer in sender peer pool memory
        at::Tensor btm_inp_tx,          // btm input transfer buffer in btm neighbor peer pool memory
        at::Tensor btm_inp_halo,        // btm input halo in receiver device memory
        at::Tensor top_signal,          // top input signal in receiver device memory
        at::Tensor btm_signal,          // btm input signal in receiver device memory
        at::Tensor waits                // top and btm signals for this rank
        );

}}}
#endif
apex/contrib/csrc/transducer/transducer_joint_kernel.cu  (view file @ 96850dfa)

@@ -5,7 +5,7 @@
 #include <torch/extension.h>
 #include <ATen/AccumulateType.h>
-#if !defined(NEW_GENERATOR_PATH)
+#ifdef OLD_GENERATOR_PATH
 #include <ATen/CUDAGeneratorImpl.h>
 #else
 #include <ATen/cuda/CUDAGeneratorImpl.h>
...
apex/contrib/fmha/fmha.py  (view file @ 96850dfa)

@@ -32,16 +32,18 @@ import fmhalib as mha
 class FMHAFun(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):
+    def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors):
         batch_size = cu_seqlens.numel() - 1
         if batch_size < 4:
-            context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
+            max_s = 512
+            context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, True, zero_tensors, None)
         else:
-            context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
+            context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, False, zero_tensors, None)
         ctx.save_for_backward(qkv, S_dmask)
         ctx.cu_seqlens = cu_seqlens
         ctx.p_dropout = p_dropout
         ctx.max_s = max_s
+        ctx.zero_tensors = zero_tensors
         return context

     @staticmethod
...
@@ -49,11 +51,11 @@ class FMHAFun(torch.autograd.Function):
         qkv, S_dmask = ctx.saved_tensors
         batch_size = ctx.cu_seqlens.numel() - 1
         if batch_size < 4:
-            dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
+            dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
         else:
-            dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
+            dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors)

-        return dqkv, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None

 class FMHA(torch.nn.Module):
...
@@ -67,8 +69,8 @@ class FMHA(torch.nn.Module):
         self.d = self.hidden_size // self.h
         assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads"

-    def forward(self, qkv, cu_seqlens, max_s, is_training=True):
-        ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training)
+    def forward(self, qkv, cu_seqlens, max_s, is_training=True, zero_tensors=False):
+        ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training, zero_tensors)
         return ctx.view(-1, self.hidden_size)
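The change above threads a zero_tensors flag from FMHA.forward down through FMHAFun into the fmhalib kernels. A hedged sketch of calling the module after this change; the constructor is not shown in the hunk, so the config object and its field names are assumptions, while the forward() arguments follow the code above:

    # Hedged usage sketch; "Config" and its fields are assumptions.
    import torch
    from apex.contrib.fmha import FMHA

    class Config:
        hidden_size = 1024
        num_attention_heads = 16
        attention_probs_dropout_prob = 0.1

    mha = FMHA(Config()).cuda().half()

    # three variable-length sequences packed along the token dimension
    seqlens = torch.tensor([128, 64, 32], device="cuda", dtype=torch.int32)
    cu_seqlens = torch.zeros(len(seqlens) + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

    total_tokens = int(seqlens.sum())
    qkv = torch.randn(total_tokens, 3 * Config.hidden_size, device="cuda", dtype=torch.half)

    out = mha(qkv, cu_seqlens, max_s=128, is_training=True, zero_tensors=False)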
apex/contrib/focal_loss/__init__.py  (new file, 0 → 100644, view file @ 96850dfa)

try:
    import torch
    import focal_loss_cuda
    from .focal_loss import focal_loss

    del torch
    del focal_loss_cuda
    del focal_loss
except ImportError as err:
    print("apex was installed without --focal_loss flag, apex.contrib.focal_loss is not available")
apex/contrib/focal_loss/focal_loss.py  (new file, 0 → 100644, view file @ 96850dfa)

import torch

import focal_loss_cuda


class FocalLoss(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        cls_output,
        cls_targets_at_level,
        num_positives_sum,
        num_real_classes,
        alpha,
        gamma,
        label_smoothing=0.0,
    ):
        loss, partial_grad = focal_loss_cuda.forward(
            cls_output,
            cls_targets_at_level,
            num_positives_sum,
            num_real_classes,
            alpha,
            gamma,
            label_smoothing,
        )

        ctx.save_for_backward(partial_grad, num_positives_sum)
        return loss

    @staticmethod
    def backward(ctx, grad_loss):
        partial_grad, num_positives_sum = ctx.saved_tensors

        # The backward kernel is actually in-place to save memory space,
        # partial_grad and grad_input are the same tensor.
        grad_input = focal_loss_cuda.backward(grad_loss, partial_grad, num_positives_sum)

        return grad_input, None, None, None, None, None, None


def focal_loss(
    cls_output: torch.Tensor,
    cls_targets_at_level: torch.Tensor,
    num_positive_sum: torch.Tensor,
    num_real_classes: int,
    alpha: float,
    gamma: float,
    label_smoothing: float = 0.0,
) -> torch.Tensor:
    """Fused focal loss function."""
    return FocalLoss.apply(
        cls_output,
        cls_targets_at_level,
        num_positive_sum,
        num_real_classes,
        alpha,
        gamma,
        label_smoothing,
    )
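For reference, the computation the focal_loss_cuda kernels fuse corresponds to the standard sigmoid focal loss normalized by the number of positives; the fused kernel's exact label-smoothing and edge-case handling are not visible in this file, so the following unfused PyTorch version is only a hedged approximation for sanity checks:

    # Hedged reference: unfused focal loss in plain PyTorch.
    import torch
    import torch.nn.functional as F

    def focal_loss_reference(logits, targets_one_hot, num_positives_sum, alpha=0.25, gamma=2.0):
        p = torch.sigmoid(logits)
        ce = F.binary_cross_entropy_with_logits(logits, targets_one_hot, reduction="none")
        p_t = p * targets_one_hot + (1.0 - p) * (1.0 - targets_one_hot)
        alpha_t = alpha * targets_one_hot + (1.0 - alpha) * (1.0 - targets_one_hot)
        loss = alpha_t * (1.0 - p_t) ** gamma * ce
        return loss.sum() / num_positives_sum

    logits = torch.randn(8, 100, 90, device="cuda")          # e.g. anchors x classes
    targets = torch.zeros_like(logits).bernoulli_(0.02)
    print(focal_loss_reference(logits, targets, targets.sum().clamp(min=1.0)))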