Unverified Commit 96850dfa authored by Jithun Nair's avatar Jithun Nair Committed by GitHub
Browse files

Merge pull request #80 from ROCmSoftwarePlatform/IFU-master-2022-07-29

IFU-master-2022-07-29
parents 87fc4125 cc5f83b5
...@@ -166,6 +166,12 @@ REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); ...@@ -166,6 +166,12 @@ REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
......
...@@ -67,9 +67,105 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params ...@@ -67,9 +67,105 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
} }
REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
......
...@@ -110,4 +110,4 @@ torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, ...@@ -110,4 +110,4 @@ torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
} }
} // namespace additive_mask_softmax_dropout } // namespace additive_mask_softmax_dropout
} // namespace fused_softmax } // namespace fused_softmax
} // namespace multihead_attn } // namespace multihead_attn
\ No newline at end of file
#pragma once #pragma once
#include <ATen/ATen.h> #include <ATen/ATen.h>
#if !defined(NEW_GENERATOR_PATH) #ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h> #include <ATen/CUDAGeneratorImpl.h>
#else #else
#include <ATen/cuda/CUDAGeneratorImpl.h> #include <ATen/cuda/CUDAGeneratorImpl.h>
......
...@@ -687,5 +687,4 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -687,5 +687,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex } // end namespace rocblas_gemmex
} // end namespace encdec_norm_add } // end namespace encdec_norm_add
} // end namespace multihead_attn } // end namespace multihead_attn
\ No newline at end of file
...@@ -67,7 +67,7 @@ __device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1, ...@@ -67,7 +67,7 @@ __device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1,
} }
// intra-warp reductions // intra-warp reductions
for (int l = 0; l <= 4; ++l) { for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31; int srcLaneB = (threadIdx.x + (1 << l)) & 31;
U muB = WARP_SHFL(mu, srcLaneB, 32); U muB = WARP_SHFL(mu, srcLaneB, 32);
U countB = WARP_SHFL(count, srcLaneB, 32); U countB = WARP_SHFL(count, srcLaneB, 32);
U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32); U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
...@@ -108,7 +108,7 @@ __device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1, ...@@ -108,7 +108,7 @@ __device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1,
// don't care about final value of count, we know count == n2 // don't care about final value of count, we know count == n2
} else { } else {
mu = WARP_SHFL(mu, 0, 32); mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2/U(n2), 0, 32); sigma2 = WARP_SHFL(sigma2 / U(n2), 0, 32);
} }
} }
} }
...@@ -158,7 +158,7 @@ __device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals, ...@@ -158,7 +158,7 @@ __device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals,
} }
// intra-warp reductions // intra-warp reductions
for (int l = 0; l <= 4; ++l) { for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31; int srcLaneB = (threadIdx.x + (1 << l)) & 31;
float muB = WARP_SHFL(mu, srcLaneB, 32); float muB = WARP_SHFL(mu, srcLaneB, 32);
float countB = WARP_SHFL(count, srcLaneB, 32); float countB = WARP_SHFL(count, srcLaneB, 32);
float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32); float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
...@@ -199,7 +199,7 @@ __device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals, ...@@ -199,7 +199,7 @@ __device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals,
// don't care about final value of count, we know count == n2 // don't care about final value of count, we know count == n2
} else { } else {
mu = WARP_SHFL(mu, 0, 32); mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2/float(n2), 0, 32); sigma2 = WARP_SHFL(sigma2 / float(n2), 0, 32);
} }
} }
} }
...@@ -521,7 +521,7 @@ cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid, ...@@ -521,7 +521,7 @@ cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid,
} }
} }
// intra-warp reductions // intra-warp reductions
for (int mask = blockDim.x/2; mask > 0; mask /= 2) { for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32); sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32); sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
} }
......
...@@ -19,19 +19,13 @@ namespace multihead_attn { ...@@ -19,19 +19,13 @@ namespace multihead_attn {
namespace self_bias_additive_mask { namespace self_bias_additive_mask {
namespace rocblas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
bool use_time_mask, int heads, torch::Tensor const& inputs,
bool is_training, torch::Tensor const& input_weights,
int heads, torch::Tensor const& output_weights,
torch::Tensor const& inputs, torch::Tensor const& input_biases,
torch::Tensor const& input_weights, torch::Tensor const& output_biases,
torch::Tensor const& output_weights, const half* pad_mask, float dropout_prob) {
torch::Tensor const& input_biases,
torch::Tensor const& output_biases,
const half* pad_mask,
float dropout_prob
)
{
const int embed_dim = inputs.size(2); const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1); const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0); const int q_seq_len = inputs.size(0);
......
...@@ -501,5 +501,4 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -501,5 +501,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex } // end namespace rocblas_gemmex
} // end namespace self } // end namespace self
} // end namespace multihead_attn } // end namespace multihead_attn
\ No newline at end of file
...@@ -577,5 +577,4 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -577,5 +577,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex } // end namespace rocblas_gemmex
} // end namespace self_norm_add } // end namespace self_norm_add
} // end namespace multihead_attn } // end namespace multihead_attn
\ No newline at end of file
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <ATen/cuda/CUDAGraphsUtils.cuh> #include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <curand_kernel.h> #include <curand_kernel.h>
#if !defined(NEW_GENERATOR_PATH) #ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h> #include <ATen/CUDAGeneratorImpl.h>
#else #else
#include <ATen/cuda/CUDAGeneratorImpl.h> #include <ATen/cuda/CUDAGeneratorImpl.h>
...@@ -1593,11 +1593,13 @@ int log2_ceil_native(int value) { ...@@ -1593,11 +1593,13 @@ int log2_ceil_native(int value) {
} }
template <typename T> template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) { __device__ __forceinline__ T
#if CUDA_VERSION >= 9000 && !defined(__HIP_PLATFORM_HCC__) WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
return __shfl_xor_sync(mask, value, laneMask, width); unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else #else
return __shfl_xor(value, laneMask, width); return __shfl_xor(value, laneMask, width);
#endif #endif
} }
......
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nccl_p2p_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id");
m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm");
m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace");
m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange");
m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay");
}
#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <ctime>
#include <cassert>
#include "nccl.h"
/*
* This file implements a crude but effective mechanism for copying data between tenors owned by different ranks
* on the same machine using cudaMemcpyAsync peer-to-peer transfers.
*/
namespace {
__global__ void AddDelay_kernel(const int delay, int* counter) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
// waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.
int new_counter = 0;
double elapsed = 0;
clock_t start = clock();
do {
clock_t now = clock();
elapsed = (double)(now - start)*1e9 / CLOCKS_PER_SEC;
++new_counter;
} while (elapsed < (double)delay);
*counter = new_counter;
}
}
class NcclCommWrapper
{
private:
ncclComm_t comm;
int rank, world_size;
ncclDataType_t get_nccl_type(at::Tensor input)
{
switch (input.scalar_type())
{
case at::ScalarType::Half:
return ncclFloat16;
case at::ScalarType::Float:
return ncclFloat32;
case at::ScalarType::Double:
return ncclFloat64;
case at::ScalarType::Byte:
return ncclUint8;
case at::ScalarType::Char:
return ncclInt8;
case at::ScalarType::Int:
return ncclInt32;
case at::ScalarType::Long:
return ncclInt64;
case at::ScalarType::BFloat16:
return ncclBfloat16;
default:
assert(false);
}
}
public:
NcclCommWrapper()
{
memset(&comm, 0, sizeof(ncclComm_t));
rank = 0;
world_size = 0;
}
NcclCommWrapper(ncclUniqueId id, int my_rank, int num_ranks)
{
ncclCommInitRank(&comm, num_ranks, id, my_rank);
rank = my_rank;
world_size = num_ranks;
}
~NcclCommWrapper()
{
printf("ncclCommDestroy()\n");
ncclCommDestroy(comm);
}
void left_right_halo_exchange_inplace(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo)
{
auto stream = at::cuda::getCurrentCUDAStream();
ncclGroupStart();
ncclDataType_t ncclType = get_nccl_type(left_output_halo);
bool left_zero = (left_rank < 0);
bool right_zero = (right_rank < 0);
size_t left_n = torch::numel(left_output_halo);
size_t right_n = torch::numel(right_output_halo);
assert(left_n > 0 && left_n == right_n);
if (left_zero) {
left_input_halo.zero_();
} else {
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, left_output_halo.scalar_type(), "left_halo_exch", [&]() {
// send left (to my_rank - 1)
ncclSend(left_output_halo.data_ptr<scalar_t>(), left_n, ncclType, left_rank, comm, stream);
// receive left (from my_rank - 1)
ncclRecv(left_input_halo.data_ptr<scalar_t>(), right_n, ncclType, left_rank, comm, stream);
});
}
if (right_zero) {
right_input_halo.zero_();
} else {
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, right_output_halo.scalar_type(), "right_halo_exch", [&]() {
// send right (to my_rank + 1 )
ncclSend(right_output_halo.data_ptr<scalar_t>(), right_n, ncclType, right_rank, comm, stream);
// receive right (from my_rank + 1)
ncclRecv(right_input_halo.data_ptr<scalar_t>(), left_n, ncclType, right_rank, comm, stream);
});
}
ncclGroupEnd();
}
std::vector<at::Tensor> left_right_halo_exchange(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo)
{
// after halo exchange:
// left_output_halo of rank+1 ends up in right_input_halo of rank
// right_output_halo of rank-1 ends up in left_input_halo of rank
auto right_input_halo = torch::empty_like(left_output_halo);
auto left_input_halo = torch::empty_like(right_output_halo);
left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo);
return {left_input_halo, right_input_halo};
}
};
class ManagedObjects
{
public:
ManagedObjects()
{
}
~ManagedObjects()
{
for (auto it = _nccl_comms.begin(); it != _nccl_comms.end(); ++it)
{
delete *it;
}
}
int add_comm(NcclCommWrapper* comm)
{
int handle = _nccl_comms.size();
_nccl_comms.push_back(comm);
return handle;
}
NcclCommWrapper& get_comm(int handle)
{
assert(handle >= 0 && handle < _nccl_comms.size());
return *_nccl_comms[handle];
}
private:
std::vector<NcclCommWrapper*> _nccl_comms;
};
class ManagedObjects mo;
} // end anonymous namespace
namespace apex { namespace contrib { namespace nccl_p2p {
at::Tensor get_unique_nccl_id(int n)
{
ncclUniqueId id;
ncclGetUniqueId(&id);
auto id_tensor = torch::empty({n,(int)sizeof(ncclUniqueId)}, torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));
auto id_ptr = id_tensor.data_ptr<uint8_t>();
size_t offset = 0;
for (int i = 0; i < n; ++i)
{
ncclUniqueId id;
ncclGetUniqueId(&id);
memcpy(id_ptr+offset, &id, sizeof(ncclUniqueId));
offset += sizeof(ncclUniqueId);
}
return id_tensor;
}
int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks)
{
ncclUniqueId id;
auto unique_nccl_id_ptr = unique_nccl_id.data_ptr<uint8_t>();
memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId));
NcclCommWrapper* comm = new NcclCommWrapper(id, my_rank, num_ranks);
int handle = mo.add_comm(comm);
comm = 0L;
return handle;
}
void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo)
{
class NcclCommWrapper& communicator = mo.get_comm(handle);
return communicator.left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo);
}
std::vector<at::Tensor> left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo)
{
class NcclCommWrapper& communicator = mo.get_comm(handle);
return communicator.left_right_halo_exchange(left_rank, right_rank, left_output_halo, right_output_halo);
}
void add_delay(int delay)
{
auto stream = at::cuda::getCurrentCUDAStream();
auto t = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
AddDelay_kernel<<<1,1,0,stream>>>(delay, t.data_ptr<int>());
}
}}}
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>
#ifndef _nccl_p2p_h_
#define _nccl_p2p_h_
namespace apex { namespace contrib { namespace nccl_p2p {
at::Tensor get_unique_nccl_id(int n);
int init_nccl_comm(
at::Tensor unique_nccl_id,
int my_rank,
int num_ranks
);
void left_right_halo_exchange_inplace(
int handle,
int left_rank,
int right_rank,
at::Tensor left_output_halo,
at::Tensor right_output_halo,
at::Tensor left_input_halo,
at::Tensor right_input_halo);
std::vector<at::Tensor> left_right_halo_exchange(
int handle,
int left_rank,
int right_rank,
at::Tensor left_output_halo,
at::Tensor right_output_halo);
void add_delay(int delay);
}}}
#endif
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "peer_memory_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw");
m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw");
m.def("zero", &apex::contrib::peer_memory::zero, "zero");
m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address");
m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers");
m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half");
m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float");
m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int");
m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d");
}
#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <cassert>
#include <cuda_runtime_api.h>
#include <cooperative_groups.h>
#include "nccl.h"
namespace cg = cooperative_groups;
#define CUDACHECK(cmd) do { \
cudaError_t err = cmd; \
if( err != cudaSuccess ) { \
char hostname[1024]; \
gethostname(hostname, 1024); \
printf("%s: CUDA failure %s:%d '%s'\n", \
hostname, \
__FILE__,__LINE__,cudaGetErrorString(err)); \
} \
} while(0)
namespace {
/* Basic deleter function for from_blob function.
void deleter(void* ptr)
{
printf("deleter(ptr=%p)\n",ptr);
cudaFree(ptr);
}
*/
template<class T>
at::Tensor blob_view(T* raw_ptr, std::vector<int64_t> shape, const at::TensorOptions& options, bool channels_last)
{
size_t size = 1;
std::vector<int64_t> strides(shape.size());
if (channels_last) {
assert(shape.size() == 4);
strides[0] = shape[1]*shape[2]*shape[3];
strides[1] = 1;
strides[2] = shape[1]*shape[3];
strides[3] = shape[1];
} else {
int idx = strides.size();
for (auto it = shape.rbegin(); it != shape.rend(); ++it)
{
strides[--idx] = size;
size *= *it;
}
}
size *= sizeof(T);
// TODO: Implement dynamic reuse of pooled peer memory.
// We provide no deleter function because all peer memory allocations are static in this implementation.
return torch::from_blob((void*)raw_ptr, shape, strides, 0L, options);
}
void tensor_shape(at::Tensor t, bool explicit_nhwc, int& N, int& C, int& H, int& W)
{
if (t.dim() == 3) {
N = 1;
if (explicit_nhwc) {
C = t.size(2);
H = t.size(0);
W = t.size(1);
} else {
C = t.size(0);
H = t.size(1);
W = t.size(2);
}
} else if (t.dim() == 4) {
if (explicit_nhwc) {
N = t.size(0);
C = t.size(3);
H = t.size(1);
W = t.size(2);
} else {
N = t.size(0);
C = t.size(1);
H = t.size(2);
W = t.size(3);
}
} else {
printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n",__FILE__,__LINE__,t.dim());
assert(t.dim() == 3 || t.dim() == 4);
}
}
void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride_C, int& stride_H, int& stride_W)
{
if (t.dim() == 3) {
if (explicit_nhwc) {
stride_C = t.stride(2);
stride_H = t.stride(0);
stride_W = t.stride(1);
} else {
stride_C = t.stride(0);
stride_H = t.stride(1);
stride_W = t.stride(2);
}
stride_N = t.size(0)*t.size(1)*t.size(2);
} else if (t.dim() == 4) {
if (explicit_nhwc) {
stride_N = t.stride(0);
stride_C = t.stride(3);
stride_H = t.stride(1);
stride_W = t.stride(2);
} else {
stride_N = t.stride(0);
stride_C = t.stride(1);
stride_H = t.stride(2);
stride_W = t.stride(3);
}
} else {
printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n",__FILE__,__LINE__,t.dim());
assert(t.dim() == 3 || t.dim() == 4);
}
}
template<class T, bool is_HWC>
__device__ void strided_copy_kernel(
T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W,
const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W,
const int NC, const int NH, const int NW
)
{
size_t tot_num_threads = gridDim.x * blockDim.x;
size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const size_t count = NC*NH*NW;
for (size_t i = thread_id; i < count; i += tot_num_threads)
{
size_t c,h,w;
if (is_HWC) {
c = i % NC;
w = i / NC;
h = w / NW;
w = w % NW;
}
else {
w = i % NW;
h = i / NW;
c = h / NH;
h = h % NH;
}
size_t dst_off = c*dst_stride_C + h*dst_stride_H + w*dst_stride_W;
size_t src_off = c*src_stride_C + h*src_stride_H + w*src_stride_W;
dst[dst_off] = src[src_off];
}
}
__device__ void checked_signal(
volatile int* signal1_flag, volatile int* signal2_flag,
const int v1, const int v2, const int v3, const int v4
)
{
cg::this_grid().sync();
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
if (is_main_thread) {
// flush all writes to global memory
__threadfence_system();
// wait for top or bottom neighbor to clear signal
register int r1, r2, r3, r4;
bool top_zeroed=false, btm_zeroed=false, top_done=false, btm_done=false;
do {
do {
if (!top_zeroed) {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
}
if (!btm_zeroed) {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
}
} while((top_zeroed == top_done) && (btm_zeroed == btm_done));
if (!top_done && top_zeroed) {
// signal to top neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
top_done = true;
}
if (!btm_done && btm_zeroed) {
// signal to bottom neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
btm_done = true;
}
} while (!top_done || !btm_done);
}
}
__device__ void wait_for(
volatile int* wait_flag,
const int v1, const int v2, const int v3, const int v4
)
{
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
if (is_main_thread) {
register int r1, r2, r3, r4;
// wait for senders to signal their output is read
do {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
} while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
}
cg::this_grid().sync(); // all threads wait for main
}
__device__ void clear_flag(
volatile int* wait_flag
)
{
cg::this_grid().sync(); // wait for all threads in kernel to finish
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
if (is_main_thread) {
register int r1, r2, r3, r4;
r1 = 0; r2 = 0; r3 = 0; r4 = 0;
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
}
}
template<class T, bool is_HWC>
#if __CUDA_ARCH__ >= 700
__launch_bounds__(128, 16)
#endif
__global__ void push_pull_halos_1d_kernel(
// top halo,
const T* toh, int toh_stride_C, int toh_stride_H, int toh_stride_W, // top output halo
T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W, // top output tx buffer
T* tix, int tix_stride_C, int tix_stride_H, int tix_stride_W, // top input tx buffer
T* tih, int tih_stride_C, int tih_stride_H, int tih_stride_W, // top input halo
// btm halo
const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W, // btm output halo
T* box, int box_stride_C, int box_stride_H, int box_stride_W, // btm output tx buffer
T* bix, int bix_stride_C, int bix_stride_H, int bix_stride_W, // btm input tx buffer
T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W, // btm input halo
// dimensions
int NC, int NH, int NW,
// signals
int* signal1_flag,
int* signal2_flag,
int* wait1_flag,
int* wait2_flag
)
{
// push top output halo to transfer buffer
strided_copy_kernel<T,is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
// push btm output halo to transfer buffer
strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
// signal to top and btm neigbhbors that output halos are ready to be read
// the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
// pull top halo from transfer buffer in peer memory to input
wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
clear_flag(wait1_flag);
// pull btm halo from transfer buffer in peer memory to input
wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
clear_flag(wait2_flag);
}
__global__ void delay_kernel(int delay_nanoseconds, int* counter)
{
if (blockIdx.x == 0 && threadIdx.x == 0) {
// waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.
int new_counter = 0;
double elapsed = 0;
clock_t start = clock();
do {
clock_t now = clock();
elapsed = (double)(now - start)*1e9 / CLOCKS_PER_SEC;
++new_counter;
} while (elapsed < (double)delay_nanoseconds);
*counter = new_counter;
}
}
}
namespace apex { namespace contrib { namespace peer_memory {
int64_t allocate_raw(int64_t size)
{
float* ptr = 0L;
cudaMalloc(&ptr, size);
cudaMemset(ptr, 0, size);
return (int64_t)ptr;
}
void free_raw(int64_t raw)
{
cudaFree((void*)raw);
}
void zero(int64_t raw, int64_t size)
{
cudaMemset((void*)raw, 0, size);
}
at::Tensor get_raw_ipc_address(int64_t raw)
{
cudaIpcMemHandle_t mem_handle;
CUDACHECK( cudaIpcGetMemHandle(&mem_handle, (void*)raw) );
const int n = sizeof(cudaIpcMemHandle_t);
auto address_tensor = torch::empty({n}, torch::dtype(torch::kUInt8));
auto address_tensor_p = address_tensor.data_ptr<uint8_t>();
memcpy(address_tensor_p, (uint8_t*)&mem_handle, n);
return address_tensor;
}
std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw)
{
int peer_group_size = ipc_addresses.size(0);
std::vector<int64_t> results(peer_group_size);
for (int i = 0; i < peer_group_size; ++i) {
if (i != peer_rank) {
cudaIpcMemHandle_t mem_handle;
memcpy(&mem_handle, ipc_addresses.index({i}).data_ptr<uint8_t>(), sizeof(cudaIpcMemHandle_t));
void* p = 0L;
CUDACHECK( cudaIpcOpenMemHandle((void**)&p, mem_handle, cudaIpcMemLazyEnablePeerAccess) );
results[i] = (int64_t)p;
} else {
results[i] = (int64_t)raw;
}
}
return results;
}
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last)
{
return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last);
}
at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last)
{
return blob_view<float>((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last);
}
at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last)
{
return blob_view<int>((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last);
}
void push_pull_halos_1d(
bool diagnostics,
bool explicit_nhwc,
int numSM, // number of SMs to use
at::Tensor top_out_halo, // top output halo in sender device memory
at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory
at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory
at::Tensor top_inp_halo, // top input halo in receiver device memory
at::Tensor btm_out_halo, // btm output halo in sender device memory
at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory
at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory
at::Tensor btm_inp_halo, // btm input halo in receiver device memory
at::Tensor top_signal, // top input signal in receiver device memory
at::Tensor btm_signal, // btm input signal in receiver device memory
at::Tensor waits // top and btm signals for this rank
)
{
// basic checks of inputs
TORCH_CHECK(top_out_halo.is_cuda());
TORCH_CHECK(top_out_tx.is_cuda());
TORCH_CHECK(top_inp_tx.is_cuda());
TORCH_CHECK(top_inp_halo.is_cuda());
TORCH_CHECK(btm_out_halo.is_cuda());
TORCH_CHECK(btm_out_tx.is_cuda());
TORCH_CHECK(btm_inp_tx.is_cuda());
TORCH_CHECK(btm_inp_halo.is_cuda());
TORCH_CHECK(top_signal.is_cuda());
TORCH_CHECK(btm_signal.is_cuda());
TORCH_CHECK(waits.is_cuda());
// shapes and strides
int toh_N, toh_C, toh_H, toh_W;
tensor_shape(top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W);
int tox_N, tox_C, tox_H, tox_W;
tensor_shape(top_out_tx, explicit_nhwc, tox_N, tox_C, tox_H, tox_W);
int tix_N, tix_C, tix_H, tix_W;
tensor_shape(top_inp_tx, explicit_nhwc, tix_N, tix_C, tix_H, tix_W);
int tih_N, tih_C, tih_H, tih_W;
tensor_shape(top_inp_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W);
TORCH_CHECK(
(toh_N == tox_N && tox_N == tix_N && tix_N == tih_N) &&
(toh_C == tox_C && tox_C == tix_C && tix_C == tih_C) &&
(toh_H == tox_H && tox_H == tix_H && tix_H == tih_H) &&
(toh_W == tox_W && tox_W == tix_W && tix_W == tih_W));
int boh_N, boh_C, boh_H, boh_W;
tensor_shape(btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W);
int box_N, box_C, box_H, box_W;
tensor_shape(btm_out_tx, explicit_nhwc, box_N, box_C, box_H, box_W);
int bix_N, bix_C, bix_H, bix_W;
tensor_shape(btm_inp_tx, explicit_nhwc, bix_N, bix_C, bix_H, bix_W);
int bih_N, bih_C, bih_H, bih_W;
tensor_shape(btm_inp_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W);
TORCH_CHECK(
(boh_N == box_N && box_N == bix_N && bix_N == bih_N) &&
(boh_C == box_C && box_C == bix_C && bix_C == bih_C) &&
(boh_H == box_H && box_H == bix_H && bix_H == bih_H) &&
(boh_W == box_W && box_W == bix_W && bix_W == bih_W));
TORCH_CHECK(
(toh_N == boh_N) &&
(toh_C == boh_C) &&
(toh_H == boh_H) &&
(toh_W == boh_W));
int NC=toh_C, NH=toh_H, NW=toh_W;
if (diagnostics) printf("NC=%d, NH=%d, NW=%d\n",NC,NH,NW);
int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W;
tensor_strides(top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);
int tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W;
tensor_strides(top_out_tx, explicit_nhwc, tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W);
int tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W;
tensor_strides(top_inp_tx, explicit_nhwc, tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W);
int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W;
tensor_strides(top_inp_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);
int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W;
tensor_strides(btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);
int box_stride_N, box_stride_C, box_stride_H, box_stride_W;
tensor_strides(btm_out_tx, explicit_nhwc, box_stride_N, box_stride_C, box_stride_H, box_stride_W);
int bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W;
tensor_strides(btm_inp_tx, explicit_nhwc, bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W);
int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W;
tensor_strides(btm_inp_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);
// determine if nhwc
auto is_nhwc = (toh_stride_C == 1) ? true : false;
if (diagnostics) printf("is_nhwc = %s\n",is_nhwc?"true":"false");
// figure out launch parameters
int device;
cudaGetDevice(&device);
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, device);
assert(numSM > 0 && numSM <= prop.multiProcessorCount);
auto current_stream = at::cuda::getCurrentCUDAStream();
const int numThreads = 128;
dim3 block(numThreads,1,1);
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, top_out_halo.scalar_type(), "push_pull_halos_1d_kernel", [&]{
if (diagnostics) printf("size(scalar_t) = %ld\n",sizeof(scalar_t));
scalar_t* toh_p = top_out_halo.data_ptr<scalar_t>();
scalar_t* tox_p = top_out_tx.data_ptr<scalar_t>();
scalar_t* tix_p = top_inp_tx.data_ptr<scalar_t>();
scalar_t* tih_p = top_inp_halo.data_ptr<scalar_t>();
scalar_t* boh_p = btm_out_halo.data_ptr<scalar_t>();
scalar_t* box_p = btm_out_tx.data_ptr<scalar_t>();
scalar_t* bix_p = btm_inp_tx.data_ptr<scalar_t>();
scalar_t* bih_p = btm_inp_halo.data_ptr<scalar_t>();
if (diagnostics) printf("waypoint1\n");
int* top_signal_p = top_signal.data_ptr<int>() + 4;
int* btm_signal_p = btm_signal.data_ptr<int>();
int* top_wait_p = waits.data_ptr<int>();
int* btm_wait_p = waits.data_ptr<int>() + 4;
if (diagnostics) printf("waypoint2\n");
// do int4 vector loads if channel count permits
int elem_size_in_bytes = toh_C * sizeof(scalar_t);
int elem_size_in_int4 = (elem_size_in_bytes / 16);
if (diagnostics) printf("elem_size_in_bytes = %d, elem_size_in_int4 = %d\n",elem_size_in_bytes,elem_size_in_int4);
if (is_nhwc && elem_size_in_int4*16 == elem_size_in_bytes) {
// can do int4 transfers
int divisor = toh_C / elem_size_in_int4;
if (diagnostics) printf("CAN DO INT4 :: divisor = %d\n",divisor);
toh_stride_N /= divisor; toh_stride_H /= divisor; toh_stride_W /= divisor;
tox_stride_N /= divisor; tox_stride_H /= divisor; tox_stride_W /= divisor;
tix_stride_N /= divisor; tix_stride_H /= divisor; tix_stride_W /= divisor;
tih_stride_N /= divisor; tih_stride_H /= divisor; tih_stride_W /= divisor;
boh_stride_N /= divisor; boh_stride_H /= divisor; boh_stride_W /= divisor;
box_stride_N /= divisor; box_stride_H /= divisor; box_stride_W /= divisor;
bix_stride_N /= divisor; bix_stride_H /= divisor; bix_stride_W /= divisor;
bih_stride_N /= divisor; bih_stride_H /= divisor; bih_stride_W /= divisor;
NC /= divisor;
if (diagnostics) {
printf("divisor=%d\n",divisor);
printf("toh_stride :: N=%d, C=%d, H=%d, W=%d\n",toh_stride_N,toh_stride_C,toh_stride_H,toh_stride_W);
printf("tox_stride :: N=%d, C=%d, H=%d, W=%d\n",tox_stride_N,tox_stride_C,tox_stride_H,tox_stride_W);
printf("tix_stride :: N=%d, C=%d, H=%d, W=%d\n",tix_stride_N,tix_stride_C,tix_stride_H,tix_stride_W);
printf("tih_stride :: N=%d, C=%d, H=%d, W=%d\n",tih_stride_N,tih_stride_C,tih_stride_H,tih_stride_W);
printf("boh_stride :: N=%d, C=%d, H=%d, W=%d\n",boh_stride_N,boh_stride_C,boh_stride_H,boh_stride_W);
printf("box_stride :: N=%d, C=%d, H=%d, W=%d\n",box_stride_N,box_stride_C,box_stride_H,box_stride_W);
printf("bix_stride :: N=%d, C=%d, H=%d, W=%d\n",bix_stride_N,bix_stride_C,bix_stride_H,bix_stride_W);
printf("bih_stride :: N=%d, C=%d, H=%d, W=%d\n",bih_stride_N,bih_stride_C,bih_stride_H,bih_stride_W);
printf("NC=%d, NH=%d, NW=%d\n",NC,NH,NW);
}
void *kernelArgs[] = {
(int4**)&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
(int4**)&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
(int4**)&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
(int4**)&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
(int4**)&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
(int4**)&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
(int4**)&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
(int4**)&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
&NC, &NH, &NW,
&top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
};
int numBlocksPerSm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
} else {
// cannot do int4 transfers
if (diagnostics) printf("CAN NOT DO INT4\n");
void *kernelArgs[] = {
&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
&NC, &NH, &NW,
&top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
};
int numBlocksPerSm;
if (is_nhwc) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true>, grid, block, kernelArgs, 0, current_stream);
} else {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false>, grid, block, kernelArgs, 0, current_stream);
}
}
} );
}
} } }
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>
#ifndef _peer_memory_h_
#define _peer_memory_h_
namespace apex { namespace contrib { namespace peer_memory {
int64_t allocate_raw(int64_t size);
void free_raw(int64_t raw);
void zero(int64_t raw, int64_t size);
at::Tensor get_raw_ipc_address(int64_t raw);
std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw);
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last);
at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last);
at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last);
void push_pull_halos_1d(
bool diagnostics,
bool explicit_nhwc,
int numSM, // number of SMs to use
at::Tensor top_out_halo, // top output halo in sender device memory
at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory
at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory
at::Tensor top_inp_halo, // top input halo in receiver device memory
at::Tensor btm_out_halo, // btm output halo in sender device memory
at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory
at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory
at::Tensor btm_inp_halo, // btm input halo in receiver device memory
at::Tensor top_signal, // top input signal in receiver device memory
at::Tensor btm_signal, // btm input signal in receiver device memory
at::Tensor waits // top and btm signals for this rank
);
} } }
#endif
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <torch/extension.h> #include <torch/extension.h>
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#if !defined(NEW_GENERATOR_PATH) #ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h> #include <ATen/CUDAGeneratorImpl.h>
#else #else
#include <ATen/cuda/CUDAGeneratorImpl.h> #include <ATen/cuda/CUDAGeneratorImpl.h>
......
...@@ -32,16 +32,18 @@ import fmhalib as mha ...@@ -32,16 +32,18 @@ import fmhalib as mha
class FMHAFun(torch.autograd.Function): class FMHAFun(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training): def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors):
batch_size = cu_seqlens.numel() - 1 batch_size = cu_seqlens.numel() - 1
if batch_size < 4: if batch_size < 4:
context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, None) max_s = 512
context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, True, zero_tensors, None)
else: else:
context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, None) context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, False, zero_tensors, None)
ctx.save_for_backward(qkv, S_dmask) ctx.save_for_backward(qkv, S_dmask)
ctx.cu_seqlens = cu_seqlens ctx.cu_seqlens = cu_seqlens
ctx.p_dropout = p_dropout ctx.p_dropout = p_dropout
ctx.max_s = max_s ctx.max_s = max_s
ctx.zero_tensors = zero_tensors
return context return context
@staticmethod @staticmethod
...@@ -49,11 +51,11 @@ class FMHAFun(torch.autograd.Function): ...@@ -49,11 +51,11 @@ class FMHAFun(torch.autograd.Function):
qkv, S_dmask = ctx.saved_tensors qkv, S_dmask = ctx.saved_tensors
batch_size = ctx.cu_seqlens.numel() - 1 batch_size = ctx.cu_seqlens.numel() - 1
if batch_size < 4: if batch_size < 4:
dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s) dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
else: else:
dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s) dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
return dqkv, None, None, None, None, None, None return dqkv, None, None, None, None, None
class FMHA(torch.nn.Module): class FMHA(torch.nn.Module):
...@@ -67,8 +69,8 @@ class FMHA(torch.nn.Module): ...@@ -67,8 +69,8 @@ class FMHA(torch.nn.Module):
self.d = self.hidden_size // self.h self.d = self.hidden_size // self.h
assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads" assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads"
def forward(self, qkv, cu_seqlens, max_s, is_training=True): def forward(self, qkv, cu_seqlens, max_s, is_training=True, zero_tensors=False):
ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training) ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training, zero_tensors)
return ctx.view(-1, self.hidden_size) return ctx.view(-1, self.hidden_size)
try:
import torch
import focal_loss_cuda
from .focal_loss import focal_loss
del torch
del focal_loss_cuda
del focal_loss
except ImportError as err:
print("apex was installed without --focal_loss flag, apex.contrib.focal_loss is not available")
import torch
import focal_loss_cuda
class FocalLoss(torch.autograd.Function):
@staticmethod
def forward(
ctx,
cls_output,
cls_targets_at_level,
num_positives_sum,
num_real_classes,
alpha,
gamma,
label_smoothing=0.0,
):
loss, partial_grad = focal_loss_cuda.forward(
cls_output,
cls_targets_at_level,
num_positives_sum,
num_real_classes,
alpha,
gamma,
label_smoothing,
)
ctx.save_for_backward(partial_grad, num_positives_sum)
return loss
@staticmethod
def backward(ctx, grad_loss):
partial_grad, num_positives_sum = ctx.saved_tensors
# The backward kernel is actually in-place to save memory space,
# partial_grad and grad_input are the same tensor.
grad_input = focal_loss_cuda.backward(grad_loss, partial_grad, num_positives_sum)
return grad_input, None, None, None, None, None, None
def focal_loss(
cls_output: torch.Tensor,
cls_targets_at_level: torch.Tensor,
num_positive_sum: torch.Tensor,
num_real_classes: int,
alpha: float,
gamma: float,
label_smoothing: float = 0.0,
) -> torch.Tensor:
"""Fused focal loss function."""
return FocalLoss.apply(
cls_output,
cls_targets_at_level,
num_positive_sum,
num_real_classes,
alpha,
gamma,
label_smoothing,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment