Commit c2b62b7f authored by JR_ZZU

delete origin files

parent 2a4864d5
#pragma once
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
//#include "cutlass/cutlass.h"
//#include "cutlass/gemm/gemm.h"
//#include "cutlass/gemm/wmma_gemm_traits.h"
// symbol to be automatically resolved by PyTorch libs
/*
rocblas_datatype a_type = rocblas_datatype_f16_r; // OK
rocblas_datatype b_type = rocblas_datatype_f16_r; // OK
rocblas_datatype c_type = rocblas_datatype_f16_r; // OK
rocblas_datatype d_type = rocblas_datatype_f16_r;
rocblas_datatype compute_type = rocblas_datatype_f32_r;
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
int32_t solution_index = 0;
rocblas_int flags = 0;
*/
namespace {
cublasOperation_t convertTransToCublasOperation(char trans) {
if (trans == 't')
return CUBLAS_OP_T;
else if (trans == 'n')
return CUBLAS_OP_N;
else if (trans == 'c')
return CUBLAS_OP_C;
else {
AT_ERROR("trans must be one of: t, n, c");
return CUBLAS_OP_T;
}
}
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
cublasOperation_t opa = convertTransToCublasOperation(transa);
cublasOperation_t opb = convertTransToCublasOperation(transb);
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
float fAlpha = alpha;
float fBeta = beta;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
(void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
}
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
auto stream = c10::cuda::getCurrentCUDAStream();
if ( (transa == 't') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
} else if ( (transa == 'n') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
} else if ( (transa == 'n') && (transb == 't') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
} else {
AT_ASSERTM(false, "TransA and TransB are invalid");
}
}
void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k,
int64_t *lda, int64_t *ldb, int64_t *ldc) {
int transa_ = ((transa == 't') || (transa == 'T'));
int transb_ = ((transb == 't') || (transb == 'T'));
// Note: leading dimensions generally are checked that they are > 0 and at
// least as big as the result requires (even if the value won't be used).
if (n <= 1)
*ldc = std::max<int64_t>(m, 1);
if (transa_) {
if (m <= 1)
*lda = std::max<int64_t>(k, 1);
} else {
if (k <= 1)
*lda = std::max<int64_t>(m, 1);
}
if (transb_) {
if (k <= 1)
*ldb = std::max<int64_t>(n, 1);
} else {
if (n <= 1)
*ldb = std::max<int64_t>(k, 1);
}
}
void HgemmStridedBatched(char transa, char transb, long m,
long n, long k, float alpha, const half *a, long lda,
long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC,
half *d, long ldd, long strideD, long batchCount) {
if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) ||
(ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX))
{
AT_ERROR("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, "
"batchCount"
"with the bound [val] <= %d",
INT_MAX);
}
adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
// gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
// b, ldb, strideB, beta, c, ldc, strideC, batchCount);
gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, 0 /*flags*/);
}
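/*
 * A minimal usage sketch of HgemmStridedBatched (illustrative only; the device
 * buffers A, B (const half*), C, D (half*) and all shapes/strides below are
 * assumptions, not taken from a caller in this file). Matrices follow the
 * column-major BLAS convention and each batch computes D = alpha*op(A)*op(B) + beta*C:
 *
 *   const long m = 64, n = 32, k = 128, batch = 16;
 *   // transa == 't': A is stored k x m, so op(A) = A^T is m x k; B is k x n.
 *   HgemmStridedBatched('t', 'n', m, n, k,
 *                       1.0f, A, k, m * k,   // lda = k, strideA = m*k
 *                             B, k, k * n,   // ldb = k, strideB = k*n
 *                       0.0f, C, m, m * n,   // ldc = m, strideC = m*n
 *                             D, m, m * n,   // ldd = m, strideD = m*n
 *                       batch);
 */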
} // namespace
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nccl_p2p_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id");
m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm");
m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace");
m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange");
m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay");
}
#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <ctime>
#include <cassert>
#ifdef __HIP_PLATFORM_HCC__
#include "rccl/rccl.h"
#else
#include "nccl.h"
#endif
/*
 * This file implements a crude but effective mechanism for copying data between tensors owned by different ranks
 * on the same machine, using NCCL point-to-point (send/recv) transfers.
*/
namespace {
__global__ void AddDelay_kernel(const int delay, int* counter) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
// waste time while doing something the compiler can't predict, thus preventing it from optimizing away this code.
int new_counter = 0;
double elapsed = 0;
clock_t start = clock();
do {
clock_t now = clock();
elapsed = (double)(now - start)*1e9 / CLOCKS_PER_SEC;
++new_counter;
} while (elapsed < (double)delay);
*counter = new_counter;
}
}
class NcclCommWrapper
{
private:
ncclComm_t comm;
int rank, world_size;
ncclDataType_t get_nccl_type(at::Tensor input)
{
switch (input.scalar_type())
{
case at::ScalarType::Half:
return ncclFloat16;
case at::ScalarType::Float:
return ncclFloat32;
case at::ScalarType::Double:
return ncclFloat64;
case at::ScalarType::Byte:
return ncclUint8;
case at::ScalarType::Char:
return ncclInt8;
case at::ScalarType::Int:
return ncclInt32;
case at::ScalarType::Long:
return ncclInt64;
case at::ScalarType::BFloat16:
return ncclBfloat16;
default:
assert(false);
}
}
public:
NcclCommWrapper()
{
memset(&comm, 0, sizeof(ncclComm_t));
rank = 0;
world_size = 0;
}
NcclCommWrapper(ncclUniqueId id, int my_rank, int num_ranks)
{
ncclCommInitRank(&comm, num_ranks, id, my_rank);
rank = my_rank;
world_size = num_ranks;
}
~NcclCommWrapper()
{
printf("ncclCommDestroy()\n");
ncclCommDestroy(comm);
}
void left_right_halo_exchange_inplace(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo)
{
auto stream = at::cuda::getCurrentCUDAStream();
ncclGroupStart();
ncclDataType_t ncclType = get_nccl_type(left_output_halo);
bool left_zero = (left_rank < 0);
bool right_zero = (right_rank < 0);
size_t left_n = torch::numel(left_output_halo);
size_t right_n = torch::numel(right_output_halo);
assert(left_n > 0 && left_n == right_n);
if (left_zero) {
left_input_halo.zero_();
} else {
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, left_output_halo.scalar_type(), "left_halo_exch", [&]() {
// send left (to my_rank - 1)
ncclSend(left_output_halo.data_ptr<scalar_t>(), left_n, ncclType, left_rank, comm, stream);
// receive left (from my_rank - 1)
ncclRecv(left_input_halo.data_ptr<scalar_t>(), right_n, ncclType, left_rank, comm, stream);
});
}
if (right_zero) {
right_input_halo.zero_();
} else {
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, right_output_halo.scalar_type(), "right_halo_exch", [&]() {
// send right (to my_rank + 1 )
ncclSend(right_output_halo.data_ptr<scalar_t>(), right_n, ncclType, right_rank, comm, stream);
// receive right (from my_rank + 1)
ncclRecv(right_input_halo.data_ptr<scalar_t>(), left_n, ncclType, right_rank, comm, stream);
});
}
ncclGroupEnd();
}
std::vector<at::Tensor> left_right_halo_exchange(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo)
{
// after halo exchange:
// left_output_halo of rank+1 ends up in right_input_halo of rank
// right_output_halo of rank-1 ends up in left_input_halo of rank
auto right_input_halo = torch::empty_like(left_output_halo);
auto left_input_halo = torch::empty_like(right_output_halo);
left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo);
return {left_input_halo, right_input_halo};
}
};
class ManagedObjects
{
public:
ManagedObjects()
{
}
~ManagedObjects()
{
for (auto it = _nccl_comms.begin(); it != _nccl_comms.end(); ++it)
{
delete *it;
}
}
int add_comm(NcclCommWrapper* comm)
{
int handle = _nccl_comms.size();
_nccl_comms.push_back(comm);
return handle;
}
NcclCommWrapper& get_comm(int handle)
{
assert(handle >= 0 && handle < _nccl_comms.size());
return *_nccl_comms[handle];
}
private:
std::vector<NcclCommWrapper*> _nccl_comms;
};
class ManagedObjects mo;
} // end anonymous namespace
namespace apex { namespace contrib { namespace nccl_p2p {
at::Tensor get_unique_nccl_id(int n)
{
auto id_tensor = torch::empty({n,(int)sizeof(ncclUniqueId)}, torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));
auto id_ptr = id_tensor.data_ptr<uint8_t>();
size_t offset = 0;
for (int i = 0; i < n; ++i)
{
ncclUniqueId id;
ncclGetUniqueId(&id);
memcpy(id_ptr+offset, &id, sizeof(ncclUniqueId));
offset += sizeof(ncclUniqueId);
}
return id_tensor;
}
int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks)
{
ncclUniqueId id;
auto unique_nccl_id_ptr = unique_nccl_id.data_ptr<uint8_t>();
memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId));
NcclCommWrapper* comm = new NcclCommWrapper(id, my_rank, num_ranks);
int handle = mo.add_comm(comm);
comm = 0L;
return handle;
}
void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo)
{
class NcclCommWrapper& communicator = mo.get_comm(handle);
return communicator.left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo);
}
std::vector<at::Tensor> left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo)
{
class NcclCommWrapper& communicator = mo.get_comm(handle);
return communicator.left_right_halo_exchange(left_rank, right_rank, left_output_halo, right_output_halo);
}
void add_delay(int delay)
{
auto stream = at::cuda::getCurrentCUDAStream();
auto t = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
AddDelay_kernel<<<1,1,0,stream>>>(delay, t.data_ptr<int>());
}
}}}
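/*
 * A minimal host-side usage sketch (illustrative only; the rank wiring and the
 * halo tensors below are assumptions, not taken from this file). One rank calls
 * get_unique_nccl_id(1) and the resulting id_tensor is distributed to every rank
 * out of band before initialization:
 *
 *   int handle = apex::contrib::nccl_p2p::init_nccl_comm(id_tensor, my_rank, num_ranks);
 *
 *   // Exchange halos with the neighbours; pass a negative rank at a boundary to
 *   // receive zero-filled halos instead of performing a transfer.
 *   auto halos = apex::contrib::nccl_p2p::left_right_halo_exchange(
 *       handle, my_rank - 1, (my_rank + 1 < num_ranks) ? my_rank + 1 : -1,
 *       left_output_halo, right_output_halo);
 *   at::Tensor left_input_halo  = halos[0];  // right_output_halo of rank-1
 *   at::Tensor right_input_halo = halos[1];  // left_output_halo of rank+1
 */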
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>
#ifndef _nccl_p2p_h_
#define _nccl_p2p_h_
namespace apex { namespace contrib { namespace nccl_p2p {
at::Tensor get_unique_nccl_id(int n);
int init_nccl_comm(
at::Tensor unique_nccl_id,
int my_rank,
int num_ranks
);
void left_right_halo_exchange_inplace(
int handle,
int left_rank,
int right_rank,
at::Tensor left_output_halo,
at::Tensor right_output_halo,
at::Tensor left_input_halo,
at::Tensor right_input_halo);
std::vector<at::Tensor> left_right_halo_exchange(
int handle,
int left_rank,
int right_rank,
at::Tensor left_output_halo,
at::Tensor right_output_halo);
void add_delay(int delay);
}}}
#endif
#include <torch/extension.h>
// CUDA forward declaration
void fused_strided_check_finite(at::Tensor & overflow_flag, at::Tensor & p_copy, int stride, int clear_overflow_first);
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_reversible_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_maybe_adam_undo_cuda(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void maybe_cast_cuda(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out);
void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// C++ interface
void strided_check_finite(
at::Tensor& overflow_flag,
at::Tensor& p_copy,
int stride,
int clear_overflow_first
) {
CHECK_INPUT(p_copy);
fused_strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first);
}
void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p);
if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
CHECK_INPUT(m);
CHECK_INPUT(v);
CHECK_INPUT(g);
int64_t num_elem = p.numel();
AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");
fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
void reversible_adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p);
if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
CHECK_INPUT(m);
CHECK_INPUT(v);
CHECK_INPUT(g);
int64_t num_elem = p.numel();
AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");
fused_reversible_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
void maybe_adam_undo(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p);
CHECK_INPUT(m);
CHECK_INPUT(v);
CHECK_INPUT(g);
int64_t num_elem = p.numel();
AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
fused_maybe_adam_undo_cuda(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out) {
CHECK_INPUT(p_in);
CHECK_INPUT(p_out);
int64_t num_elem = p_in.numel();
AT_ASSERTM(p_out.numel() == num_elem, "number of elements in p_in and p_out should be equal");
maybe_cast_cuda(overflow_flag, p_in, p_out);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
m.def("adam", &adam, "Adam optimized CUDA implementation.");
m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.");
m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.");
m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
}
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include "ATen/TensorUtils.h"
// #include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
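// The two helpers above support the vectorized path of the kernels below:
// is_aligned() checks that a pointer is aligned to an ILP-element boundary, and
// load_store() moves ILP consecutive elements (e.g. 4 halves = 8 bytes,
// 4 floats = 16 bytes) as one aligned chunk, so each thread issues a single
// wide load/store instead of ILP scalar ones.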
#include "type_shim.h"
typedef enum{
ADAM_MODE_0 =0, // eps under square root
ADAM_MODE_1 =1 // eps outside square root
} adamMode_t;
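// Per element, the Adam kernels below compute (with scaled_grad = g/grad_scale):
//   m = b1*m + (1-b1)*scaled_grad
//   v = b2*v + (1-b2)*scaled_grad^2
//   denom = sqrt(v + eps)      (ADAM_MODE_0)
//         = sqrt(v) + eps      (ADAM_MODE_1)
//   p = p - step_size * (m/denom + decay*p)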
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
for (int j = i; j < tsize; j+=totThreads) {
T scaled_grad = g[j]/grad_scale;
m[j] = b1*m[j] + (1-b1)*scaled_grad;
v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
float update = (m[j]/denom) + (decay*p[j]);
p[j] = p[j] - (step_size*update);
if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
}
}
template <int DEPTH, typename T, typename GRAD_T>
struct AdamFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<DEPTH>& tl,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
adamMode_t mode,
const float decay)
{
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T* p = (T *)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T *)tl.addresses[1][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T *)tl.addresses[2][tensor_loc];
v += chunk_idx*chunk_size;
GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
g += chunk_idx*chunk_size;
GRAD_T* p_copy = NULL;
if (DEPTH == 5) {
p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
p_copy += chunk_idx*chunk_size;
}
n -= chunk_idx*chunk_size;
T incoming_p[ILP];
T incoming_m[ILP];
T incoming_v[ILP];
T incoming_g[ILP];
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 &&
chunk_size % ILP == 0 &&
is_aligned(p) &&
is_aligned(m) &&
is_aligned(v) &&
is_aligned(g) &&
is_aligned(p_copy))
{
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
GRAD_T tmp_g[ILP];
load_store(incoming_p, p, 0, i_start);
load_store(incoming_m, m, 0, i_start);
load_store(incoming_v, v, 0, i_start);
load_store(tmp_g, g, 0, i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
incoming_g[ii] = static_cast<T>(tmp_g[ii]);
T scaled_grad = incoming_g[ii]/grad_scale;
incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(incoming_v[ii] + eps);
else // Mode 1
denom = sqrtf(incoming_v[ii]) + eps;
float update = (incoming_m[ii]/denom) + (decay*incoming_p[ii]);
incoming_p[ii] = incoming_p[ii] - (step_size*update);
if (DEPTH == 5) tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);
}
load_store(p, incoming_p, i_start, 0);
load_store(m, incoming_m, i_start, 0);
load_store(v, incoming_v, i_start, 0);
if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);
}
}
else
{
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
incoming_p[ii] = 0;
incoming_m[ii] = 0;
incoming_v[ii] = 0;
incoming_g[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if (i < n && i < chunk_size) {
incoming_p[ii] = p[i];
incoming_m[ii] = m[i];
incoming_v[ii] = v[i];
incoming_g[ii] = static_cast<T>(g[i]);
}
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = i_start + threadIdx.x + ii*blockDim.x;
if(j < n && j < chunk_size) {
T scaled_grad = incoming_g[ii]/grad_scale;
m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
float update = (m[j]/denom) + (decay*incoming_p[ii]);
p[j] = incoming_p[ii] - (step_size*update);
if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
}
}
}
}
}
};
void fused_adam_cuda(
at::Tensor & p,
at::Tensor & p_copy,
at::Tensor & m,
at::Tensor & v,
at::Tensor & g,
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay)
{
// using namespace at;
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
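// With bias correction enabled, the effective step size below is
//   step_size = lr * sqrt(1 - beta2^step) / (1 - beta1^step);
// otherwise the raw learning rate is used.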
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<scalar_t_0>(),
NULL, //don't output p_copy for fp32, it's wasted write
m.DATA_PTR<scalar_t_0>(),
v.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
C10_CUDA_CHECK(cudaGetLastError());
}
void fused_adam_cuda_mt(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay) {
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half || tensor_lists[3][0].scalar_type() == at::ScalarType::BFloat16) {
//all other values should be fp32 for half gradients
AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
if (tl_sz == 5) {
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
}
} else {
if (tl_sz == 5) {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
}
}
C10_CUDA_CHECK(cudaGetLastError());
}
template <typename FROM_T, typename TO_T>
__device__ void convert(const FROM_T vi, TO_T& vo)
{
vo = static_cast<TO_T>(vi);
}
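// The specializations below pack a float/half value into a single "e5m2" byte:
// the stored byte is the high byte of the fp16 representation (sign, 5 exponent
// bits, top 2 mantissa bits). Masking the float bits with 0xFF800000 keeps only
// the sign and exponent of the input, i.e. +/-2^e; adding one eighth of that
// (half an e5m2 spacing) before the fp16 cast makes the subsequent truncation of
// the low mantissa byte behave approximately like round-to-nearest. Decoding
// re-installs the byte as the high byte of an fp16 with the low byte zeroed.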
template <>
__device__ void convert(const float vi, uint8_t& vo)
{
union S
{
float as_float;
int as_int;
};
S s;
s.as_float = vi;
s.as_int = s.as_int & 0xFF800000;
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
vo = t.as_byte[1];
}
template <>
__device__ void convert(const uint8_t vi, float& vo)
{
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_byte[0] = 0;
t.as_byte[1] = vi;
vo = static_cast<float>(t.as_half);
}
template <>
__device__ void convert(const at::Half vi, uint8_t& vo)
{
union S
{
float as_float;
int as_int;
};
S s;
s.as_float = static_cast<float>(vi);
s.as_int = s.as_int & 0xFF800000;
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
vo = t.as_byte[1];
}
template <>
__device__ void convert(const uint8_t vi, at::Half& vo)
{
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_byte[0] = 0;
t.as_byte[1] = vi;
vo = t.as_half;
}
template <typename GRAD_T>
__global__ void strided_check_finite_cuda_kernel(
volatile int* noop_gmem,
GRAD_T* __restrict__ p_copy,
const size_t tsize,
int stride,
int clear_overflow_first)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock*stride;
if (clear_overflow_first) {
if (i == 0) {
*noop_gmem = 0;
}
__syncthreads();
}
for (int j = i; j < tsize; j+=totThreads) {
GRAD_T pi = p_copy[j];
if (!isfinite(pi)) {
*noop_gmem = 1;
}
}
}
template <>
__global__ void strided_check_finite_cuda_kernel(
volatile int* noop_gmem,
uint8_t* __restrict__ p_copy,
const size_t tsize,
int stride,
int clear_overflow_first)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock*stride;
if (clear_overflow_first) {
if (i == 0) {
*noop_gmem = 0;
}
__syncthreads();
}
for (int j = i; j < tsize; j+=totThreads) {
at::Half pi;
convert(p_copy[j], pi);
if (!isfinite(pi)) {
*noop_gmem = 1;
}
}
}
template <typename FROM_T, typename TO_T>
__global__ void maybe_cast_kernel(
volatile int* overflow_flag,
const FROM_T* p_in,
TO_T* p_out,
const size_t tsize)
{
if (overflow_flag && *overflow_flag != 0) return;
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
FROM_T pi[ILP];
TO_T po[ILP];
for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
pi[ii] = 0;
int j = j_start + i + totThreads*ii;
if (j < tsize) {
pi[ii] = p_in[j];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
convert(pi[ii], po[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + i + totThreads*ii;
if (j < tsize) {
p_out[j] = po[ii];
}
}
}
}
template <typename T, typename GRAD_T, typename REDU_T>
__global__ void reversible_adam_cuda_kernel(
T* __restrict__ p,
REDU_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
T mi[ILP];
T vi[ILP];
T pi[ILP];
T gi[ILP];
bool overflow = false;
for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
mi[ii] = T(0);
vi[ii] = T(0);
pi[ii] = T(0);
gi[ii] = GRAD_T(0);
int j = j_start + i + totThreads*ii;
if (j < tsize) {
pi[ii] = p[j];
mi[ii] = m[j];
vi[ii] = v[j];
gi[ii] = static_cast<T>(g[j]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
T scaled_grad = gi[ii]/grad_scale;
if (isfinite(scaled_grad)) {
mi[ii] = b1*mi[ii] + (1-b1)*scaled_grad;
vi[ii] = b2*vi[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vi[ii] + eps);
else // Mode 1
denom = sqrtf(vi[ii]) + eps;
float update = (mi[ii]/denom) + (decay*pi[ii]);
pi[ii] = pi[ii] - (step_size*update);
} else {
overflow = true;
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + i + totThreads*ii;
if (j < tsize) {
m[j] = mi[ii];
v[j] = vi[ii];
p[j] = pi[ii];
if (p_copy != NULL) {
convert(pi[ii], p_copy[j]);
}
}
}
}
if (p_copy != NULL) {
__syncthreads();
if (overflow) {
convert(float(INFINITY), p_copy[0]);
}
}
}
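// The undo kernel below inverts one Adam step. Given the forward update
//   p' = p - step_size*(m'/denom + decay*p) = p*(1 - step_size*decay) - step_size*(m'/denom)
// with m' = b1*m + (1-b1)*scaled_grad and v' = b2*v + (1-b2)*scaled_grad^2,
// the previous state is recovered as
//   p = (p' + step_size*(m'/denom)) / (1 - step_size*decay)
//   m = (m' - (1-b1)*scaled_grad) / b1
//   v = (v' - (1-b2)*scaled_grad^2) / b2
// where denom is computed from the post-update v', matching the forward pass.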
template <typename T, typename GRAD_T>
__global__ void maybe_adam_undo_cuda_kernel(
volatile int* overflow_flag,
T* __restrict__ p,
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay)
{
// NB! Skip undo kernel when overflow flag is NOT set
if (overflow_flag && *overflow_flag == 0) return;
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
T mi[ILP];
T vi[ILP];
T pi[ILP];
T gi[ILP];
for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
mi[ii] = T(0);
vi[ii] = T(0);
pi[ii] = T(0);
gi[ii] = GRAD_T(0);
int j = j_start + i + totThreads*ii;
if (j < tsize) {
pi[ii] = p[j];
mi[ii] = m[j];
vi[ii] = v[j];
gi[ii] = static_cast<T>(g[j]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
T scaled_grad = gi[ii]/grad_scale;
if (isfinite(scaled_grad)) {
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vi[ii] + eps);
else // Mode 1
denom = sqrtf(vi[ii]) + eps;
pi[ii] = (pi[ii] + step_size*(mi[ii]/denom)) / (1.0f - step_size*decay);
mi[ii] = (mi[ii] - (1-b1)*scaled_grad) / b1;
vi[ii] = (vi[ii] - (1-b2)*scaled_grad*scaled_grad) / b2;
// Make sure round off errors don't create (small) negative value.
// This can happen if we have to revert the very first step.
vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f;
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + i + totThreads*ii;
if (j < tsize) {
m[j] = mi[ii];
v[j] = vi[ii];
p[j] = pi[ii];
}
}
}
}
template <int DEPTH, typename FROM_T, typename TO_T>
struct MaybeCastFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* overflow_flag,
TensorListMetadata<DEPTH>& tl)
{
if (overflow_flag && *overflow_flag != 0) return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
FROM_T* p_in = (FROM_T *)tl.addresses[0][tensor_loc];
p_in += chunk_idx*chunk_size;
TO_T* p_out = (TO_T *)tl.addresses[1][tensor_loc];
p_out += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
int dim = chunk_size < n ? chunk_size : n;
FROM_T pi[ILP];
TO_T po[ILP];
for(int j_start = 0; j_start < dim; j_start+=blockDim.x*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
pi[ii] = FROM_T(0);
int j = j_start + threadIdx.x + ii*blockDim.x;
if (j < dim) {
pi[ii] = p_in[j];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
convert(pi[ii], po[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = j_start + threadIdx.x + ii*blockDim.x;
if (j < dim) {
p_out[j] = po[ii];
}
}
}
}
};
void fused_strided_check_finite(
at::Tensor & overflow_flag,
at::Tensor & p_copy,
int stride,
int clear_overflow_first)
{
//Get tensor size
int tsize = p_copy.numel();
int niter = (tsize + stride - 1) / stride;
//Determine #threads and #blocks
const int threadsPerBlock = 512;
//In order to avoid a race condition, blocks must be 1 when the clear_overflow_first flag is set.
const dim3 blocks(clear_overflow_first ? 1 : (niter+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_copy), "parameter tensor is too large to be indexed with int32");
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_HALF_AND_BYTE(p_copy.scalar_type(), 0, "check_finite_cuda_kernel",
strided_check_finite_cuda_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
overflow_flag.DATA_PTR<int>(),
p_copy.DATA_PTR<scalar_t_0>(),
tsize,
stride,
clear_overflow_first);
);
C10_CUDA_CHECK(cudaGetLastError());
}
void fused_reversible_adam_cuda(
at::Tensor & p,
at::Tensor & p_copy,
at::Tensor & m,
at::Tensor & v,
at::Tensor & g,
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay)
{
// using namespace at;
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) {
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
} else {
AT_ASSERTM(p_copy.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type");
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_e5m2_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, uint8_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(),
p_copy.DATA_PTR<uint8_t>(),
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
reversible_adam_cuda_kernel<scalar_t_0, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<scalar_t_0>(),
NULL, //don't output p_copy for fp32, it's wasted write
m.DATA_PTR<scalar_t_0>(),
v.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
C10_CUDA_CHECK(cudaGetLastError());
}
void maybe_cast_cuda(
at::Tensor & overflow_flag,
at::Tensor & p_in,
at::Tensor & p_out)
{
//Get tensor size
int tsize = p_in.numel();
AT_ASSERTM(tsize == p_out.numel(), "p_in.numel() must equal p_out.numel()");
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32");
//Constants
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_FLOAT_HALF_AND_BYTE(p_in.scalar_type(), 0, "maybe_cast_cuda",
DISPATCH_FLOAT_HALF_AND_BYTE(p_out.scalar_type(), 1, "maybe_cast_cuda",
maybe_cast_kernel<scalar_t_0,scalar_t_1><<<blocks,threadsPerBlock, 0, stream>>>(
overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
p_in.DATA_PTR<scalar_t_0>(),
p_out.DATA_PTR<scalar_t_1>(),
tsize); ))
C10_CUDA_CHECK(cudaGetLastError());
}
void maybe_cast_cuda_mt(
int chunk_size,
at::Tensor overflow_flag,
std::vector<std::vector<at::Tensor>> tensor_lists) // p_in, p_out
{
//Constants
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 2, "expected tensor lists of size 2");
DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[0][0].scalar_type(), 0, "maybe_cast_cuda_mt_kernel",
DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 1, "maybe_cast_cuda_mt_kernel",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
overflow_flag,
tensor_lists,
MaybeCastFunctor<2, scalar_t_0, scalar_t_1>()); ))
C10_CUDA_CHECK(cudaGetLastError());
}
void fused_maybe_adam_undo_cuda(
at::Tensor & overflow_flag,
at::Tensor & p,
at::Tensor & m,
at::Tensor & v,
at::Tensor & g,
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay)
{
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
maybe_adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
p.DATA_PTR<accscalar_t>(),
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
maybe_adam_undo_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
p.DATA_PTR<scalar_t_0>(),
m.DATA_PTR<scalar_t_0>(),
v.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
C10_CUDA_CHECK(cudaGetLastError());
}
#include <torch/extension.h>
void multi_tensor_lamb_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int bias_correction,
const float weight_decay,
const int grad_averaging,
const int mode,
const float global_grad_norm,
const float max_grad_norm);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("lamb", &multi_tensor_lamb_cuda, "Computes and apply update for LAMB optimizer");
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
typedef enum{
MOMENT_MODE_0 =0, // L2 regularization mode
MOMENT_MODE_1 =1 // Decoupled weight decay mode
} adamMode_t;
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
using MATH_T = float;
template<typename T>
struct LAMBStage1Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<4>& tl,
const float beta1,
const float beta2,
const float beta3,
const float beta1_correction,
const float beta2_correction,
const float epsilon,
adamMode_t mode,
const float decay,
const float global_grad_norm,
const float max_global_grad_norm)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx*chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T*)tl.addresses[2][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T*)tl.addresses[3][tensor_loc];
v += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
// see note in multi_tensor_scale_kernel.cu
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_g[ii] = g[i];
// special ?optimization? for lamb stage 1
if (decay == 0) {
r_p[ii] = MATH_T(0);
}
else {
r_p[ii] = p[i];
}
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
// L2 on scaled grad
scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = next_m_unbiased / denom;
}
else {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
g[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
};
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
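// Stage 1 (above) leaves the Adam-style update term in g; stage 2 rescales it
// with the per-tensor trust ratio before applying it:
//   ratio = lr * ||p|| / ||update||   if decay != 0 and both norms are non-zero
//         = lr                        otherwise
//   p     = p - ratio * update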
template<typename T>
struct LAMBStage2Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<2>& tl,
const float* per_tensor_param_norm,
const float* per_tensor_update_norm,
const float learning_rate,
const float decay)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
MATH_T ratio = learning_rate;
// apply adaptive learning rate to parameters with non-zero weight decay
if (decay != 0.0)
{
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
}
T* update = (T*)tl.addresses[0][tensor_loc];
update += chunk_idx*chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
MATH_T r_p[ILP];
MATH_T r_update[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_p[ii] = p[i];
r_update[ii] = update[i];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
p[i] = r_p[ii];
}
}
}
}
};
void multi_tensor_lamb_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int bias_correction,
const float weight_decay,
const int grad_averaging,
const int mode,
const float global_grad_norm,
const float max_grad_norm)
{
using namespace at;
// Master weights and 32-bit momentum (which may change) are not handled here,
// so we assume every tensor is of the same type.
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}
// Handle grad averaging mode
float beta3 = 1.0f;
if (grad_averaging == 1) beta3 = 1 - beta1;
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);
// Compute per tensor param norm
auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
// We now modify grad in place to store the update before computing its norm.
// Generally this is not an issue since people modify grad in the step() method all the time.
// We could also grab a list of empty tensors to avoid this, but I'd like to save space/CPU code.
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
LAMBStage1Functor<scalar_t_0>(),
beta1,
beta2,
beta3, // 1-beta1 or 1 depends on averaging mode
bias_correction1,
bias_correction2,
epsilon,
(adamMode_t) mode,
weight_decay,
global_grad_norm,
max_grad_norm); )
// Compute update norms
auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);
std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
noop_flag,
grad_param_list,
LAMBStage2Functor<scalar_t_0>(),
std::get<1>(param_norm_tuple).DATA_PTR<float>(),
std::get<1>(update_norm_tuple).DATA_PTR<float>(),
lr,
weight_decay); )
AT_CUDA_CHECK(cudaGetLastError());
}
#include <torch/extension.h>
void multi_tensor_fused_adam_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_beta1,
at::Tensor per_tensor_beta2,
at::Tensor per_tensor_bias_correction,
at::Tensor per_tensor_eps,
at::Tensor per_tensor_weight_decay,
float lr,
float grad_scale,
int step,
int mode);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda,
"Multi tensor Adam optimized CUDA implementation.");
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include <cmath>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
typedef enum{
ADAM_MODE_0 =0, // eps under square root
ADAM_MODE_1 =1 // eps outside square root
} adamMode_t;
template <int DEPTH, typename T, typename GRAD_T>
struct DistAdamFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<DEPTH>& tl,
const float* per_tensor_beta1,
const float* per_tensor_beta2,
const int* per_tensor_bias_correction,
const float* per_tensor_eps,
const float* per_tensor_weight_decay,
const float lr,
const float grad_scale,
const int step,
adamMode_t mode)
{
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float b1 = per_tensor_beta1[tensor_num];
float b2 = per_tensor_beta2[tensor_num];
float eps = per_tensor_eps[tensor_num];
float decay = per_tensor_weight_decay[tensor_num];
float beta1_correction = 1.0f, beta2_correction = 1.0f;
if (per_tensor_bias_correction[tensor_num] == 1) {
beta1_correction = 1 - std::pow(b1, step);
beta2_correction = 1 - std::pow(b2, step);
}
T* p = (T *)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T *)tl.addresses[1][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T *)tl.addresses[2][tensor_loc];
v += chunk_idx*chunk_size;
GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
g += chunk_idx*chunk_size;
GRAD_T* p_copy = NULL;
if (DEPTH == 5) {
p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
p_copy += chunk_idx*chunk_size;
}
n -= chunk_idx*chunk_size;
T incoming_p[ILP];
T incoming_m[ILP];
T incoming_v[ILP];
T incoming_g[ILP];
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 &&
chunk_size % ILP == 0 &&
is_aligned(p) &&
is_aligned(m) &&
is_aligned(v) &&
is_aligned(g) &&
is_aligned(p_copy)) {
for (int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) {
// load
GRAD_T tmp_g[ILP];
load_store(incoming_p, p, 0, i_start);
load_store(incoming_m, m, 0, i_start);
load_store(incoming_v, v, 0, i_start);
load_store(tmp_g, g, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
incoming_g[ii] = static_cast<T>(tmp_g[ii]);
T scaled_grad = incoming_g[ii]/grad_scale;
incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
T next_m_unbiased = incoming_m[ii] / beta1_correction;
T next_v_unbiased = incoming_v[ii] / beta2_correction;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(next_v_unbiased + eps);
else // Mode 1
denom = sqrtf(next_v_unbiased) + eps;
float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);
incoming_p[ii] = incoming_p[ii] - (lr * update);
if (DEPTH == 5) tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);
}
load_store(p, incoming_p, i_start, 0);
load_store(m, incoming_m, i_start, 0);
load_store(v, incoming_v, i_start, 0);
if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);
}
} else {
for (int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
incoming_p[ii] = 0;
incoming_m[ii] = 0;
incoming_v[ii] = 0;
incoming_g[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if (i < n && i < chunk_size) {
incoming_p[ii] = p[i];
incoming_m[ii] = m[i];
incoming_v[ii] = v[i];
incoming_g[ii] = static_cast<T>(g[i]);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int j = i_start + threadIdx.x + ii*blockDim.x;
if (j < n && j < chunk_size) {
T scaled_grad = incoming_g[ii]/grad_scale;
m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
T next_m_unbiased = m[j] / beta1_correction;
T next_v_unbiased = v[j] / beta2_correction;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(next_v_unbiased + eps);
else // Mode 1
denom = sqrtf(next_v_unbiased) + eps;
float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]);
p[j] = incoming_p[ii] - (lr * update);
if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
}
}
}
}
}
};
void multi_tensor_fused_adam_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
at::Tensor per_tensor_beta1,
at::Tensor per_tensor_beta2,
at::Tensor per_tensor_bias_correction,
at::Tensor per_tensor_eps,
at::Tensor per_tensor_weight_decay,
float lr,
float grad_scale,
int step,
int mode)
{
using namespace at;
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
if (tl_sz == 5) {
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
DistAdamFunctor<5, accscalar_t, scalar_t_0>(),
per_tensor_beta1.DATA_PTR<float>(),
per_tensor_beta2.DATA_PTR<float>(),
per_tensor_bias_correction.DATA_PTR<int>(),
per_tensor_eps.DATA_PTR<float>(),
per_tensor_weight_decay.DATA_PTR<float>(),
lr,
grad_scale,
step,
(adamMode_t) mode);
);
} else {
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
DistAdamFunctor<4, accscalar_t, scalar_t_0>(),
per_tensor_beta1.DATA_PTR<float>(),
per_tensor_beta2.DATA_PTR<float>(),
per_tensor_bias_correction.DATA_PTR<int>(),
per_tensor_eps.DATA_PTR<float>(),
per_tensor_weight_decay.DATA_PTR<float>(),
lr,
grad_scale,
step,
(adamMode_t) mode);
);
}
C10_CUDA_CHECK(cudaGetLastError());
}
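// Layout of tensor_lists expected by DistAdamFunctor (taken from the address indices
// it reads; names are illustrative):
//   tensor_lists[0] -> parameters p
//   tensor_lists[1] -> first moments m
//   tensor_lists[2] -> second moments v
//   tensor_lists[3] -> gradients g (its dtype drives the dispatch above)
//   tensor_lists[4] -> optional low-precision copy of the updated parameters (DEPTH == 5)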
#include <torch/extension.h>
void multi_tensor_lamb_compute_update_term_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_beta1,
at::Tensor per_tensor_beta2,
at::Tensor per_tensor_beta3,
at::Tensor per_tensor_bias_correction,
at::Tensor step,
at::Tensor per_tensor_epsilon,
const int mode,
at::Tensor per_tensor_decay,
at::Tensor global_scale,
at::Tensor global_grad_norm,
const float max_grad_norm);
void multi_tensor_lamb_update_weights_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
at::Tensor update_norm_offset,
at::Tensor learning_rate,
at::Tensor per_tensor_decay,
at::Tensor global_grad_norm,
bool use_nvlamb);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term_cuda,
"Computes update term for LAMB optimizer");
m.def("multi_tensor_lamb_update_weights", &multi_tensor_lamb_update_weights_cuda,
"Applies update term for LAMB optimizer");
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
template <typename FROM_T, typename TO_T>
__device__ void convert(const FROM_T vi, TO_T& vo)
{
vo = static_cast<TO_T>(vi);
}
template <>
__device__ void convert(const float vi, uint8_t& vo)
{
union S
{
float as_float;
int as_int;
};
S s;
s.as_float = vi;
s.as_int = s.as_int & 0xFF800000;
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
vo = t.as_byte[1];
}
template <>
__device__ void convert(const uint8_t vi, float& vo)
{
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_byte[0] = 0;
t.as_byte[1] = vi;
vo = static_cast<float>(t.as_half);
}
template <>
__device__ void convert(const at::Half vi, uint8_t& vo)
{
union S
{
float as_float;
int as_int;
};
S s;
s.as_float = static_cast<float>(vi);
s.as_int = s.as_int & 0xFF800000;
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
vo = t.as_byte[1];
}
template <>
__device__ void convert(const uint8_t vi, at::Half& vo)
{
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_byte[0] = 0;
t.as_byte[1] = vi;
vo = t.as_half;
}
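// Note on the uint8_t specializations above: the value appears to be rounded to
// nearest by adding s.as_float/8 (half an ulp of the truncated format, since
// s.as_float keeps only the sign and exponent), after which only the high byte of
// the resulting at::Half (sign, 5 exponent bits, top 2 mantissa bits) is stored.
// Decoding simply zero-fills the low byte. The byte selection assumes a
// little-endian layout of at::Half.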
typedef enum{
MOMENT_MODE_0 =0, // L2 regularization mode
MOMENT_MODE_1 =1 // Decoupled weight decay mode
} adamMode_t;
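// For reference (matching the two branches in the functor below):
//   MOMENT_MODE_0: classic L2 regularization -- decay*p is folded into the scaled
//                  gradient before the moment updates.
//   MOMENT_MODE_1: decoupled weight decay (AdamW style) -- the moments are decay-free
//                  and decay*p is added to the update term instead.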
template<typename T, typename GRAD_T, typename MATH_T>
struct DistOptLAMBStage1Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<5>& tl,
const MATH_T* per_tensor_beta1,
const MATH_T* per_tensor_beta2,
const MATH_T* per_tensor_beta3,
const int* per_tensor_bias_correction,
const int* step,
const MATH_T* per_tensor_epsilon,
adamMode_t mode,
const MATH_T* per_tensor_decay,
const MATH_T* global_scale,
const MATH_T* global_grad_norm,
const float max_grad_norm)
{
// I'd like this kernel to propagate infs/nans.
if (*noop_gmem == 1)
return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float combined_scale = *global_scale;
if (max_grad_norm > 0) {
combined_scale = max_grad_norm / (*global_grad_norm / *global_scale + 1e-6);
combined_scale = *global_scale / std::min((float) 1.0, combined_scale);
}
MATH_T beta1 = per_tensor_beta1[tensor_num];
MATH_T beta2 = per_tensor_beta2[tensor_num];
MATH_T beta3 = 1 - beta1;
MATH_T beta1_correction, beta2_correction;
if (per_tensor_bias_correction[tensor_num] == 1) {
beta1_correction = 1 - pow(beta1, *step);
beta2_correction = 1 - pow(beta2, *step);
} else {
beta1_correction = (MATH_T) 1.0;
beta2_correction = (MATH_T) 1.0;
}
MATH_T epsilon = per_tensor_epsilon[tensor_num];
MATH_T decay = per_tensor_decay[tensor_num];
GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc];
g += chunk_idx*chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T*)tl.addresses[2][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T*)tl.addresses[3][tensor_loc];
v += chunk_idx*chunk_size;
MATH_T* u = (MATH_T*)tl.addresses[4][tensor_loc];
u += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
// to keep things simple, the aligned case takes a separate code path
if(n % ILP == 0 &&
chunk_size % ILP == 0 &&
is_aligned(g) &&
is_aligned(p) &&
is_aligned(m) &&
is_aligned(v))
{
GRAD_T l_g[ILP];
T l_p[ILP];
T l_m[ILP];
T l_v[ILP];
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(l_g, g, 0, i_start);
if (decay != 0)
load_store(l_p, p, 0, i_start);
load_store(l_m, m, 0, i_start);
load_store(l_v, v, 0, i_start);
// unpack
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_g[ii] = l_g[ii];
if (decay == 0) {
r_p[ii] = MATH_T(0);
}
else {
r_p[ii] = l_p[ii];
}
r_m[ii] = l_m[ii];
r_v[ii] = l_v[ii];
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / combined_scale;
// L2 on scaled grad
scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = next_m_unbiased / denom;
}
else {
MATH_T scaled_grad = r_g[ii] / combined_scale;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
l_m[ii] = r_m[ii];
l_v[ii] = r_v[ii];
}
// store
load_store(u, r_p, i_start, 0);
load_store(m, l_m, i_start, 0);
load_store(v, l_v, i_start, 0);
}
}
else
{
// see note in multi_tensor_scale_kernel.cu
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_g[ii] = g[i];
// special-case handling for LAMB stage 1: when decay == 0 the parameter value
// is not needed, so skip loading it
if (decay == 0) {
r_p[ii] = MATH_T(0);
}
else {
r_p[ii] = p[i];
}
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / combined_scale;
// L2 on scaled grad
scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = next_m_unbiased / denom;
}
else {
MATH_T scaled_grad = r_g[ii] / combined_scale;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
u[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
}
};
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template<typename T, typename GRAD_T, typename MATH_T>
struct DistOptLAMBStage2Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<3>& tl,
const MATH_T* per_tensor_param_norm,
const MATH_T* per_tensor_update_norm,
const long* update_norm_offset,
const MATH_T* learning_rate,
const MATH_T* per_tensor_decay,
const MATH_T* global_grad_norm,
bool use_nvlamb)
{
// I'd like this kernel to propagate infs/nans.
if (*noop_gmem == 1)
return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
MATH_T decay = per_tensor_decay[tensor_num];
MATH_T ratio = *learning_rate;
// nvlamb: apply adaptive learning rate to all parameters
// otherwise, only apply to those with non-zero weight decay
if (use_nvlamb || (decay != (MATH_T) 0.0))
{
MATH_T param_norm = per_tensor_param_norm[tensor_num];
MATH_T update_norm = per_tensor_update_norm[update_norm_offset[tensor_num]];
ratio = (update_norm != 0.0 && param_norm != 0.0) ? (*learning_rate) * (param_norm / update_norm) : (*learning_rate);
}
MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];
update += chunk_idx*chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;
GRAD_T* p_copy = (GRAD_T*)tl.addresses[2][tensor_loc];
p_copy += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
// to keep things simple, the aligned case takes a separate code path
if(n % ILP == 0 &&
chunk_size % ILP == 0 &&
is_aligned(p) &&
is_aligned(update))
{
T r_p[ILP];
MATH_T r_update[ILP];
GRAD_T r_p_copy[ILP];
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(r_p, p, 0, i_start);
load_store(r_update, update, 0, i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);
convert(r_p[ii], r_p_copy[ii]);
}
load_store(p, r_p, i_start, 0);
load_store(p_copy, r_p_copy, i_start, 0);
}
}
else
{
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
MATH_T r_p[ILP];
MATH_T r_update[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_p[ii] = p[i];
r_update[ii] = update[i];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
p[i] = r_p[ii];
convert(r_p[ii], p_copy[i]);
}
}
}
}
}
};
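// Sketch of the tensor_lists layouts implied by the two functors above (names are
// illustrative, taken from the address indices they read):
//   Stage 1, compute_update_term (5 lists): [0] grads g, [1] params p, [2] exp_avg m,
//                                           [3] exp_avg_sq v, [4] update term u
//   Stage 2, update_weights (3 lists):      [0] update term u, [1] params p,
//                                           [2] low-precision copy of the params p_copy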
void multi_tensor_lamb_compute_update_term_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_beta1,
at::Tensor per_tensor_beta2,
at::Tensor per_tensor_beta3,
at::Tensor per_tensor_bias_correction,
at::Tensor step,
at::Tensor per_tensor_epsilon,
const int mode,
at::Tensor per_tensor_decay,
at::Tensor global_scale,
at::Tensor global_grad_norm,
const float max_grad_norm)
{
using namespace at;
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 0, "lamb_stage_1",
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 1, "lamb_stage_1",
DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
DistOptLAMBStage1Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
per_tensor_beta1.DATA_PTR<scalar_t_2>(),
per_tensor_beta2.DATA_PTR<scalar_t_2>(),
per_tensor_beta3.DATA_PTR<scalar_t_2>(),
per_tensor_bias_correction.DATA_PTR<int>(),
step.DATA_PTR<int>(),
per_tensor_epsilon.DATA_PTR<scalar_t_2>(),
(adamMode_t) mode,
per_tensor_decay.DATA_PTR<scalar_t_2>(),
global_scale.DATA_PTR<scalar_t_2>(),
global_grad_norm.DATA_PTR<scalar_t_2>(),
max_grad_norm); )))
AT_CUDA_CHECK(cudaGetLastError());
}
void multi_tensor_lamb_update_weights_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
at::Tensor update_norm_offset,
at::Tensor learning_rate,
at::Tensor per_tensor_decay,
at::Tensor global_grad_norm,
bool use_nvlamb)
{
using namespace at;
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 0, "lamb_stage_2",
DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[2][0].scalar_type(), 1, "lamb_stage_2",
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 2, "lamb_stage_2",
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
DistOptLAMBStage2Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
per_tensor_param_norm.DATA_PTR<scalar_t_2>(),
per_tensor_update_norm.DATA_PTR<scalar_t_2>(),
update_norm_offset.DATA_PTR<long>(),
learning_rate.DATA_PTR<scalar_t_2>(),
per_tensor_decay.DATA_PTR<scalar_t_2>(),
global_grad_norm.DATA_PTR<scalar_t_2>(),
use_nvlamb); )))
AT_CUDA_CHECK(cudaGetLastError());
}
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "peer_memory_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw");
m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw");
m.def("zero", &apex::contrib::peer_memory::zero, "zero");
m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address");
m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers");
m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half");
m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float");
m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int");
m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d");
}
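// Rough call sequence for the peer-memory halo exchange exposed above (a hypothetical
// host-side sketch; the actual orchestration lives in the Python layer that consumes
// this module):
//   1. each rank: raw = allocate_raw(pool_bytes); handle = get_raw_ipc_address(raw)
//   2. exchange handles across ranks, then peers = get_raw_peers(handles, my_rank, raw)
//   3. carve transfer buffers and signal words out of the pools with blob_view_* / zero
//   4. per step: push_pull_halos_1d(..., halo tensors, tx buffers, signals, waits)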
#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <cassert>
#include <cuda_runtime_api.h>
#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_cooperative_groups.h>
#include "rccl/rccl.h"
#else
#include <cooperative_groups.h>
#include "nccl.h"
#endif
namespace cg = cooperative_groups;
#define CUDACHECK(cmd) do { \
cudaError_t err = cmd; \
if( err != cudaSuccess ) { \
char hostname[1024]; \
gethostname(hostname, 1024); \
printf("%s: CUDA failure %s:%d '%s'\n", \
hostname, \
__FILE__,__LINE__,cudaGetErrorString(err)); \
} \
} while(0)
// C++17 removes 'register' storage keyword
#if __cplusplus < 201703L
#define REGISTER register
#else
#define REGISTER
#endif
namespace {
/* Basic deleter function for from_blob function.
void deleter(void* ptr)
{
printf("deleter(ptr=%p)\n",ptr);
cudaFree(ptr);
}
*/
template<class T>
at::Tensor blob_view(T* raw_ptr, std::vector<int64_t> shape, const at::TensorOptions& options, bool channels_last)
{
size_t size = 1;
std::vector<int64_t> strides(shape.size());
if (channels_last) {
assert(shape.size() == 4);
strides[0] = shape[1]*shape[2]*shape[3];
strides[1] = 1;
strides[2] = shape[1]*shape[3];
strides[3] = shape[1];
} else {
int idx = strides.size();
for (auto it = shape.rbegin(); it != shape.rend(); ++it)
{
strides[--idx] = size;
size *= *it;
}
}
size *= sizeof(T);
// TODO: Implement dynamic reuse of pooled peer memory.
// We provide no deleter function because all peer memory allocations are static in this implementation.
return torch::from_blob((void*)raw_ptr, shape, strides, 0L, options);
}
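// Worked stride example for the channels_last branch, with shape {N,C,H,W} = {1,64,8,8}:
// strides come out as {64*8*8, 1, 64*8, 64} = {4096, 1, 512, 64}, i.e. C is the
// fastest-moving dimension while the logical shape stays NCHW.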
void tensor_shape(at::Tensor t, bool explicit_nhwc, int& N, int& C, int& H, int& W)
{
if (t.dim() == 3) {
N = 1;
if (explicit_nhwc) {
C = t.size(2);
H = t.size(0);
W = t.size(1);
} else {
C = t.size(0);
H = t.size(1);
W = t.size(2);
}
} else if (t.dim() == 4) {
if (explicit_nhwc) {
N = t.size(0);
C = t.size(3);
H = t.size(1);
W = t.size(2);
} else {
N = t.size(0);
C = t.size(1);
H = t.size(2);
W = t.size(3);
}
} else {
printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n",__FILE__,__LINE__,t.dim());
assert(t.dim() == 3 || t.dim() == 4);
}
}
void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride_C, int& stride_H, int& stride_W)
{
if (t.dim() == 3) {
if (explicit_nhwc) {
stride_C = t.stride(2);
stride_H = t.stride(0);
stride_W = t.stride(1);
} else {
stride_C = t.stride(0);
stride_H = t.stride(1);
stride_W = t.stride(2);
}
stride_N = t.size(0)*t.size(1)*t.size(2);
} else if (t.dim() == 4) {
if (explicit_nhwc) {
stride_N = t.stride(0);
stride_C = t.stride(3);
stride_H = t.stride(1);
stride_W = t.stride(2);
} else {
stride_N = t.stride(0);
stride_C = t.stride(1);
stride_H = t.stride(2);
stride_W = t.stride(3);
}
} else {
printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n",__FILE__,__LINE__,t.dim());
assert(t.dim() == 3 || t.dim() == 4);
}
}
template<class T>
__device__ void __zero(T* dst)
{
*dst = T(0);
}
__device__ void __zero(int4* dst)
{
int4 v;
v.x = v.y = v.z = v.w = 0;
*dst = v;
}
template<class T, bool is_HWC, bool zero>
__device__ void strided_copy_kernel(
T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W,
const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W,
const int NC, const int NH, const int NW
)
{
size_t tot_num_threads = gridDim.x * blockDim.x;
size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const size_t count = NC*NH*NW;
for (size_t i = thread_id; i < count; i += tot_num_threads)
{
size_t c,h,w;
if (is_HWC) {
w = i / NC;
c = i - w * NC;
h = w / NW;
w = w - h * NW;
}
else {
h = i / NW;
w = i - h * NW;
c = h / NH;
h = h - c * NH;
}
size_t dst_off = c*dst_stride_C + h*dst_stride_H + w*dst_stride_W;
if (zero) {
__zero(dst+dst_off);
} else {
size_t src_off = c*src_stride_C + h*src_stride_H + w*src_stride_W;
dst[dst_off] = src[src_off];
}
}
}
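// Index decomposition used above, for reference: in the is_HWC case the flat index i
// is interpreted as ((h*NW + w)*NC + c) with channels fastest; otherwise as
// ((c*NH + h)*NW + w) with width fastest, matching the stride arguments supplied by
// the host launcher.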
template<bool top_zero, bool btm_zero>
__device__ void checked_signal(
volatile int* signal1_flag, volatile int* signal2_flag,
const int v1, const int v2, const int v3, const int v4
)
{
cg::this_grid().sync();
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0);
if (is_main_thread) {
// flush all writes to global memory
__threadfence_system();
// wait for top or bottom neighbor to clear signal
REGISTER int r1, r2, r3, r4;
if (!(top_zero || btm_zero)) {
bool top_zeroed=false, top_done=false;
bool btm_zeroed=false, btm_done=false;
do {
do {
if (!top_zeroed) {
#ifdef __HIP_PLATFORM_HCC__
r1 = __builtin_nontemporal_load(signal1_flag);
r2 = __builtin_nontemporal_load(signal1_flag + 1);
r3 = __builtin_nontemporal_load(signal1_flag + 2);
r4 = __builtin_nontemporal_load(signal1_flag + 3);
#else
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
#endif
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
}
if (!btm_zeroed) {
#ifdef __HIP_PLATFORM_HCC__
r1 = __builtin_nontemporal_load(signal2_flag);
r2 = __builtin_nontemporal_load(signal2_flag + 1);
r3 = __builtin_nontemporal_load(signal2_flag + 2);
r4 = __builtin_nontemporal_load(signal2_flag + 3);
#else
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
#endif
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
}
} while((top_zeroed == top_done) && (btm_zeroed == btm_done));
if (!top_done && top_zeroed) {
// signal to top neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store(v1, signal1_flag);
__builtin_nontemporal_store(v2, signal1_flag + 1);
__builtin_nontemporal_store(v3, signal1_flag + 2);
__builtin_nontemporal_store(v4, signal1_flag + 3);
#else
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
#endif
top_done = true;
}
if (!btm_done && btm_zeroed) {
// signal to bottom neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store(v1, signal2_flag);
__builtin_nontemporal_store(v2, signal2_flag + 1);
__builtin_nontemporal_store(v3, signal2_flag + 2);
__builtin_nontemporal_store(v4, signal2_flag + 3);
#else
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
#endif
btm_done = true;
}
} while (!top_done || !btm_done);
} else if (top_zero) {
bool btm_zeroed=false, btm_done=false;
do {
do {
if (!btm_zeroed) {
#ifdef __HIP_PLATFORM_HCC__
r1 = __builtin_nontemporal_load(signal2_flag);
r2 = __builtin_nontemporal_load(signal2_flag + 1);
r3 = __builtin_nontemporal_load(signal2_flag + 2);
r4 = __builtin_nontemporal_load(signal2_flag + 3);
#else
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
#endif
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
}
} while(btm_zeroed == btm_done);
if (!btm_done && btm_zeroed) {
// signal to bottom neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store(v1, signal2_flag);
__builtin_nontemporal_store(v2, signal2_flag + 1);
__builtin_nontemporal_store(v3, signal2_flag + 2);
__builtin_nontemporal_store(v4, signal2_flag + 3);
#else
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
#endif
btm_done = true;
}
} while (!btm_done);
} else if (btm_zero) {
bool top_zeroed=false, top_done=false;
do {
do {
if (!top_zeroed) {
#ifdef __HIP_PLATFORM_HCC__
r1 = __builtin_nontemporal_load(signal1_flag);
r2 = __builtin_nontemporal_load(signal1_flag + 1);
r3 = __builtin_nontemporal_load(signal1_flag + 2);
r4 = __builtin_nontemporal_load(signal1_flag + 3);
#else
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
#endif
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
}
} while(top_zeroed == top_done);
if (!top_done && top_zeroed) {
// signal to top neighbor my output is ready
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store(v1, signal1_flag);
__builtin_nontemporal_store(v2, signal1_flag + 1);
__builtin_nontemporal_store(v3, signal1_flag + 2);
__builtin_nontemporal_store(v4, signal1_flag + 3);
#else
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
#endif
top_done = true;
}
} while (!top_done);
}
}
}
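// Handshake sketch: the grid's single main thread spins until a neighbor's signal
// slot no longer holds (v1..v4) -- i.e. the neighbor has cleared the previous signal --
// and then writes (v1..v4) into that slot to announce that this rank's output halo is
// ready. wait_for() below is the matching receive side, and clear_flag() resets the
// slot once the halo has been copied out of the transfer buffer.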
__device__ void wait_for(
volatile int* wait_flag,
const int v1, const int v2, const int v3, const int v4
)
{
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0);
if (is_main_thread) {
REGISTER int r1, r2, r3, r4;
// wait for senders to signal their output is ready
do {
#ifdef __HIP_PLATFORM_HCC__
r1 = __builtin_nontemporal_load(wait_flag);
r2 = __builtin_nontemporal_load(wait_flag + 1);
r3 = __builtin_nontemporal_load(wait_flag + 2);
r4 = __builtin_nontemporal_load(wait_flag + 3);
#else
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
#endif
} while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
}
cg::this_grid().sync(); // all threads wait for main
}
__device__ void clear_flag(
volatile int* wait_flag
)
{
cg::this_grid().sync(); // wait for all threads in kernel to finish
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0);
if (is_main_thread) {
REGISTER int r1, r2, r3, r4;
r1 = 0; r2 = 0; r3 = 0; r4 = 0;
#ifdef __HIP_PLATFORM_HCC__
__builtin_nontemporal_store(r1, wait_flag);
__builtin_nontemporal_store(r2, wait_flag + 1);
__builtin_nontemporal_store(r3, wait_flag + 2);
__builtin_nontemporal_store(r4, wait_flag + 3);
#else
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
#endif
}
}
template<class T, bool is_HWC, bool top_zero, bool btm_zero>
#if __CUDA_ARCH__ == 700 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 900
__launch_bounds__(128, 16)
#endif
__global__ void push_pull_halos_1d_kernel(
// top halo,
const T* toh, int toh_stride_C, int toh_stride_H, int toh_stride_W, // top output halo
T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W, // top output tx buffer
T* tix, int tix_stride_C, int tix_stride_H, int tix_stride_W, // top input tx buffer
T* tih, int tih_stride_C, int tih_stride_H, int tih_stride_W, // top input halo
// btm halo
const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W, // btm output halo
T* box, int box_stride_C, int box_stride_H, int box_stride_W, // btm output tx buffer
T* bix, int bix_stride_C, int bix_stride_H, int bix_stride_W, // btm input tx buffer
T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W, // btm input halo
// dimensions
int NC, int NH, int NW,
// signals
int* signal1_flag,
int* signal2_flag,
int* wait1_flag,
int* wait2_flag
)
{
// push top output halo to transfer buffer
if (!top_zero) strided_copy_kernel<T,is_HWC,false>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
// push btm output halo to transfer buffer
if (!btm_zero) strided_copy_kernel<T,is_HWC,false>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
// signal to top and btm neighbors that output halos are ready to be read
// the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
if (!(top_zero || btm_zero)) {
checked_signal<false,false>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
} else if (top_zero) {
checked_signal<true,false>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
} else if (btm_zero) {
checked_signal<false,true>(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
}
// pull top halo from transfer buffer in peer memory to input
if (top_zero) {
strided_copy_kernel<T,is_HWC,true>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
} else {
wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC,false>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
clear_flag(wait1_flag);
}
// pull btm halo from transfer buffer in peer memory to input
if (btm_zero) {
strided_copy_kernel<T,is_HWC,true>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
} else {
wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC,false>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
clear_flag(wait2_flag);
}
}
__global__ void delay_kernel(int delay_nanoseconds, int* counter)
{
if (blockIdx.x == 0 && threadIdx.x == 0) {
// waste time while doing something the compiler can't predict, preventing it from optimizing this loop away.
int new_counter = 0;
double elapsed = 0;
clock_t start = clock();
do {
clock_t now = clock();
elapsed = (double)(now - start)*1e9 / CLOCKS_PER_SEC;
++new_counter;
} while (elapsed < (double)delay_nanoseconds);
*counter = new_counter;
}
}
}
namespace apex { namespace contrib { namespace peer_memory {
int64_t allocate_raw(int64_t size)
{
float* ptr = 0L;
CUDACHECK( cudaMalloc(&ptr, size) );
CUDACHECK( cudaMemset(ptr, 0, size) );
return (int64_t)ptr;
}
void free_raw(int64_t raw)
{
cudaFree((void*)raw);
}
void zero(int64_t raw, int64_t size)
{
cudaMemset((void*)raw, 0, size);
}
at::Tensor get_raw_ipc_address(int64_t raw)
{
cudaIpcMemHandle_t mem_handle;
CUDACHECK( cudaIpcGetMemHandle(&mem_handle, (void*)raw) );
const int n = sizeof(cudaIpcMemHandle_t);
auto address_tensor = torch::empty({n}, torch::dtype(torch::kUInt8));
auto address_tensor_p = address_tensor.data_ptr<uint8_t>();
memcpy(address_tensor_p, (uint8_t*)&mem_handle, n);
return address_tensor;
}
std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw)
{
int peer_group_size = ipc_addresses.size(0);
std::vector<int64_t> results(peer_group_size);
for (int i = 0; i < peer_group_size; ++i) {
if (i != peer_rank) {
cudaIpcMemHandle_t mem_handle;
memcpy(&mem_handle, ipc_addresses.index({i}).data_ptr<uint8_t>(), sizeof(cudaIpcMemHandle_t));
void* p = 0L;
CUDACHECK( cudaIpcOpenMemHandle((void**)&p, mem_handle, cudaIpcMemLazyEnablePeerAccess) );
results[i] = (int64_t)p;
} else {
results[i] = (int64_t)raw;
}
}
return results;
}
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last)
{
return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last);
}
at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last)
{
return blob_view<float>((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last);
}
at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last)
{
return blob_view<int>((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last);
}
void push_pull_halos_1d(
bool diagnostics,
bool explicit_nhwc,
int numSM, // number of SMs to use
bool top_zero, // true if top halo should be zeroed
at::Tensor top_out_halo, // top output halo in sender device memory
at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory
at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory
at::Tensor top_inp_halo, // top input halo in receiver device memory
bool btm_zero, // true if btm halo should be zeroed
at::Tensor btm_out_halo, // btm output halo in sender device memory
at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory
at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory
at::Tensor btm_inp_halo, // btm input halo in receiver device memory
at::Tensor top_signal, // top input signal in receiver device memory
at::Tensor btm_signal, // btm input signal in receiver device memory
at::Tensor waits // top and btm signals for this rank
)
{
// basic checks of inputs
TORCH_CHECK(top_out_halo.is_cuda());
TORCH_CHECK(top_out_tx.is_cuda());
TORCH_CHECK(top_inp_tx.is_cuda());
TORCH_CHECK(top_inp_halo.is_cuda());
TORCH_CHECK(btm_out_halo.is_cuda());
TORCH_CHECK(btm_out_tx.is_cuda());
TORCH_CHECK(btm_inp_tx.is_cuda());
TORCH_CHECK(btm_inp_halo.is_cuda());
TORCH_CHECK(top_signal.is_cuda());
TORCH_CHECK(btm_signal.is_cuda());
TORCH_CHECK(waits.is_cuda());
TORCH_CHECK(!(top_zero && btm_zero));
// shapes and strides
int toh_N, toh_C, toh_H, toh_W;
tensor_shape(top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W);
int tox_N, tox_C, tox_H, tox_W;
tensor_shape(top_out_tx, explicit_nhwc, tox_N, tox_C, tox_H, tox_W);
int tix_N, tix_C, tix_H, tix_W;
tensor_shape(top_inp_tx, explicit_nhwc, tix_N, tix_C, tix_H, tix_W);
int tih_N, tih_C, tih_H, tih_W;
tensor_shape(top_inp_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W);
TORCH_CHECK(
(toh_N == tox_N && tox_N == tix_N && tix_N == tih_N) &&
(toh_C == tox_C && tox_C == tix_C && tix_C == tih_C) &&
(toh_H == tox_H && tox_H == tix_H && tix_H == tih_H) &&
(toh_W == tox_W && tox_W == tix_W && tix_W == tih_W));
int boh_N, boh_C, boh_H, boh_W;
tensor_shape(btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W);
int box_N, box_C, box_H, box_W;
tensor_shape(btm_out_tx, explicit_nhwc, box_N, box_C, box_H, box_W);
int bix_N, bix_C, bix_H, bix_W;
tensor_shape(btm_inp_tx, explicit_nhwc, bix_N, bix_C, bix_H, bix_W);
int bih_N, bih_C, bih_H, bih_W;
tensor_shape(btm_inp_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W);
TORCH_CHECK(
(boh_N == box_N && box_N == bix_N && bix_N == bih_N) &&
(boh_C == box_C && box_C == bix_C && bix_C == bih_C) &&
(boh_H == box_H && box_H == bix_H && bix_H == bih_H) &&
(boh_W == box_W && box_W == bix_W && bix_W == bih_W));
TORCH_CHECK(
(toh_N == boh_N) &&
(toh_C == boh_C) &&
(toh_H == boh_H) &&
(toh_W == boh_W));
int NC=toh_C, NH=toh_H, NW=toh_W;
if (diagnostics) printf("NC=%d, NH=%d, NW=%d\n",NC,NH,NW);
int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W;
tensor_strides(top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);
int tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W;
tensor_strides(top_out_tx, explicit_nhwc, tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W);
int tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W;
tensor_strides(top_inp_tx, explicit_nhwc, tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W);
int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W;
tensor_strides(top_inp_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);
int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W;
tensor_strides(btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);
int box_stride_N, box_stride_C, box_stride_H, box_stride_W;
tensor_strides(btm_out_tx, explicit_nhwc, box_stride_N, box_stride_C, box_stride_H, box_stride_W);
int bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W;
tensor_strides(btm_inp_tx, explicit_nhwc, bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W);
int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W;
tensor_strides(btm_inp_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);
// determine if nhwc
bool is_nhwc = (toh_stride_C == 1);
if (diagnostics) printf("is_nhwc = %s\n",is_nhwc?"true":"false");
// figure out launch parameters
int device;
cudaGetDevice(&device);
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, device);
assert(numSM > 0 && numSM <= prop.multiProcessorCount);
auto current_stream = at::cuda::getCurrentCUDAStream();
const int numThreads = 128;
dim3 block(numThreads,1,1);
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, top_out_halo.scalar_type(), "push_pull_halos_1d_kernel", [&]{
if (diagnostics) printf("size(scalar_t) = %ld\n",sizeof(scalar_t));
scalar_t* toh_p = top_out_halo.data_ptr<scalar_t>();
scalar_t* tox_p = top_out_tx.data_ptr<scalar_t>();
scalar_t* tix_p = top_inp_tx.data_ptr<scalar_t>();
scalar_t* tih_p = top_inp_halo.data_ptr<scalar_t>();
scalar_t* boh_p = btm_out_halo.data_ptr<scalar_t>();
scalar_t* box_p = btm_out_tx.data_ptr<scalar_t>();
scalar_t* bix_p = btm_inp_tx.data_ptr<scalar_t>();
scalar_t* bih_p = btm_inp_halo.data_ptr<scalar_t>();
if (diagnostics) printf("waypoint1\n");
int* top_signal_p = top_signal.data_ptr<int>() + 4;
int* btm_signal_p = btm_signal.data_ptr<int>();
int* top_wait_p = waits.data_ptr<int>();
int* btm_wait_p = waits.data_ptr<int>() + 4;
if (diagnostics) printf("waypoint2\n");
// do int4 vector loads if channel count permits
int elem_size_in_bytes = toh_C * sizeof(scalar_t);
int elem_size_in_int4 = (elem_size_in_bytes / 16);
if (diagnostics) printf("elem_size_in_bytes = %d, elem_size_in_int4 = %d\n",elem_size_in_bytes,elem_size_in_int4);
if (is_nhwc && elem_size_in_int4*16 == elem_size_in_bytes) {
// can do int4 transfers
int divisor = toh_C / elem_size_in_int4;
if (diagnostics) printf("CAN DO INT4 :: divisor = %d\n",divisor);
toh_stride_N /= divisor; toh_stride_H /= divisor; toh_stride_W /= divisor;
tox_stride_N /= divisor; tox_stride_H /= divisor; tox_stride_W /= divisor;
tix_stride_N /= divisor; tix_stride_H /= divisor; tix_stride_W /= divisor;
tih_stride_N /= divisor; tih_stride_H /= divisor; tih_stride_W /= divisor;
boh_stride_N /= divisor; boh_stride_H /= divisor; boh_stride_W /= divisor;
box_stride_N /= divisor; box_stride_H /= divisor; box_stride_W /= divisor;
bix_stride_N /= divisor; bix_stride_H /= divisor; bix_stride_W /= divisor;
bih_stride_N /= divisor; bih_stride_H /= divisor; bih_stride_W /= divisor;
NC /= divisor;
if (diagnostics) {
printf("divisor=%d\n",divisor);
printf("toh_stride :: N=%d, C=%d, H=%d, W=%d\n",toh_stride_N,toh_stride_C,toh_stride_H,toh_stride_W);
printf("tox_stride :: N=%d, C=%d, H=%d, W=%d\n",tox_stride_N,tox_stride_C,tox_stride_H,tox_stride_W);
printf("tix_stride :: N=%d, C=%d, H=%d, W=%d\n",tix_stride_N,tix_stride_C,tix_stride_H,tix_stride_W);
printf("tih_stride :: N=%d, C=%d, H=%d, W=%d\n",tih_stride_N,tih_stride_C,tih_stride_H,tih_stride_W);
printf("boh_stride :: N=%d, C=%d, H=%d, W=%d\n",boh_stride_N,boh_stride_C,boh_stride_H,boh_stride_W);
printf("box_stride :: N=%d, C=%d, H=%d, W=%d\n",box_stride_N,box_stride_C,box_stride_H,box_stride_W);
printf("bix_stride :: N=%d, C=%d, H=%d, W=%d\n",bix_stride_N,bix_stride_C,bix_stride_H,bix_stride_W);
printf("bih_stride :: N=%d, C=%d, H=%d, W=%d\n",bih_stride_N,bih_stride_C,bih_stride_H,bih_stride_W);
printf("NC=%d, NH=%d, NW=%d\n",NC,NH,NW);
}
void *kernelArgs[] = {
(int4**)&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
(int4**)&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
(int4**)&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
(int4**)&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
(int4**)&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
(int4**)&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
(int4**)&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
(int4**)&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
&NC, &NH, &NW,
&top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
};
if (top_zero) {
int numBlocksPerSm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,true,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#endif
} else if (btm_zero) {
int numBlocksPerSm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,false,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#endif
} else {
int numBlocksPerSm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true,false,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#endif
}
} else {
// cannot do int4 transfers
if (diagnostics) printf("CAN NOT DO INT4\n");
void *kernelArgs[] = {
&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
&NC, &NH, &NW,
&top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
};
int numBlocksPerSm;
if (is_nhwc) {
if (top_zero) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,true,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,true,false>, grid, block, kernelArgs, 0, current_stream);
#endif
} else if (btm_zero) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,false,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,true>, grid, block, kernelArgs, 0, current_stream);
#endif
} else {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true,false,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true,false,false>, grid, block, kernelArgs, 0, current_stream);
#endif
}
} else {
if (top_zero) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,true,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,true,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,true,false>, grid, block, kernelArgs, 0, current_stream);
#endif
} else if (btm_zero) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,false,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,true>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,true>, grid, block, kernelArgs, 0, current_stream);
#endif
} else {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false,false,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
#ifdef __HIP_PLATFORM_HCC__
hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,false>, grid, block, kernelArgs, 0, current_stream);
#else
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false,false,false>, grid, block, kernelArgs, 0, current_stream);
#endif
}
}
}
} );
}
} } }
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>
#ifndef _peer_memory_h_
#define _peer_memory_h_
namespace apex { namespace contrib { namespace peer_memory {
int64_t allocate_raw(int64_t size);
void free_raw(int64_t raw);
void zero(int64_t raw, int64_t size);
at::Tensor get_raw_ipc_address(int64_t raw);
std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw);
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last);
at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last);
at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last);
void push_pull_halos_1d(
bool diagnostics,
bool explicit_nhwc,
int numSM, // number of SMs to use
bool top_zero, // true if top halo should be zeroed
at::Tensor top_out_halo, // top output halo in sender device memory
at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory
at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory
at::Tensor top_inp_halo, // top input halo in receiver device memory
bool btm_zero, // true if btm halo should be zeroed
at::Tensor btm_out_halo, // btm output halo in sender device memory
at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory
at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory
at::Tensor btm_inp_halo, // btm input halo in receiver device memory
at::Tensor top_signal, // top input signal in receiver device memory
at::Tensor btm_signal, // btm input signal in receiver device memory
at::Tensor waits // top and btm signals for this rank
);
} } }
#endif
#include <torch/extension.h>
#include <ATen/Functions.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// wrapped in do { } while (0) so CHECK_INPUT behaves as a single statement, e.g. after "if (packOutput)"
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
std::vector<torch::Tensor> transducer_joint_cuda_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int64_t packedBatch,
int opt,
bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize);
std::vector<torch::Tensor> transducer_joint_cuda_backward(
std::vector<torch::Tensor> in,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput,
float scale);
std::vector<torch::Tensor> transducer_joint_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int64_t packedBatch,
int opt,
bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize) {
CHECK_INPUT(f);
CHECK_INPUT(g);
CHECK_INPUT(fLen);
CHECK_INPUT(gLen);
if (packOutput)
CHECK_INPUT(batchOffset);
return transducer_joint_cuda_forward(
f,
g,
fLen,
gLen,
batchOffset,
packedBatch,
opt,
packOutput,
relu,
dropout,
dropoutProb,
tileSize);
}
std::vector<torch::Tensor> transducer_joint_backward(
std::vector<torch::Tensor> in,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput,
float scale) {
for (auto t : in){
CHECK_INPUT(t);
}
CHECK_INPUT(fLen);
CHECK_INPUT(gLen);
if (packOutput)
CHECK_INPUT(batchOffset);
return transducer_joint_cuda_backward(
in,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
packOutput,
scale);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &transducer_joint_forward, "transducer joint forward (CUDA)");
m.def("backward", &transducer_joint_backward, "transducer joint backward (CUDA)");
}
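// Shape conventions for the bindings above (taken from the kernel comments in the
// implementation): f is [B, T, H], g is [B, U, H]; the joint output is [B, T, U, H],
// or [B_packed, H] when packOutput is true, with batchOffset supplying the cumulative
// per-batch offsets into the packed tensor.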
#include <cuda.h>
#include <curand_kernel.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/AccumulateType.h>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <c10/macros/Macros.h>
#include "philox.cuh"
#ifdef __HIP_PLATFORM_HCC__
#define SHFL_DOWN(val, laneMask, width) __shfl_down(val, laneMask, width)
#else
#define SHFL_DOWN(val, laneMask, width) __shfl_down_sync(0xffffffff, val, laneMask, width)
#endif
// Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width.
// width should be a power of 2 and no larger than warpSize.
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduce(scalar_t x, int width=C10_WARP_SIZE){
for (unsigned offset = width/2; offset > 0; offset /= 2){
x += SHFL_DOWN(x, offset, width);
}
return x;
}
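// Example: with width == warpSize the full warp sum lands in lane 0; with width == 8
// each group of 8 consecutive lanes produces an independent partial sum, valid in the
// first lane of that group.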
inline int largestPowerOfTwo(int x){
int y = 1;
while (y <= x)
y <<= 1;
return y >> 1;
}
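// e.g. largestPowerOfTwo(5) == 4, largestPowerOfTwo(8) == 8, largestPowerOfTwo(1) == 1.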
/*
Figure out vectorization type for masks.
Similar to how PyTorch figures out acc_t here:
aten/src/ATen/AccumulateType.h
*/
template <int V>
struct MaskVecType { };
template <> struct MaskVecType<1> { using type = uint8_t; };
template <> struct MaskVecType<2> { using type = uint16_t; };
template <> struct MaskVecType<4> { using type = uint32_t; };
template<int V>
using mvec_type = typename MaskVecType<V>::type;
// Helper class to calculate pointer offset that can be shared by different flavors of kernels.
// For fwd, batch offset and stride are different for packing and non-packing mode.
struct OffsetCalFwd{
__device__ __forceinline__ OffsetCalFwd(
int64_t batch,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t gLen,
int64_t hiddenSize,
bool packOutput) :
batch(batch),
batchOffset(batchOffset),
maxFLen(maxFLen),
maxGLen(maxGLen),
gLen(gLen),
hiddenSize(hiddenSize),
packOutput(packOutput)
{}
int64_t batch;
const int64_t *batchOffset;
int64_t maxFLen;
int64_t maxGLen;
int64_t gLen;
int64_t hiddenSize;
bool packOutput;
__device__ __forceinline__ int64_t getBatchOffset(){
return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize
: batch*maxFLen*maxGLen*hiddenSize;
}
__device__ __forceinline__ int64_t getStrideF(){
return packOutput ? gLen*hiddenSize : maxGLen*hiddenSize;
}
};
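// Packed-offset sketch: batchOffset is presumably the inclusive prefix sum of
// fLen[b]*gLen[b], so in packing mode batch b starts batchOffset[b-1]*hiddenSize
// elements into the output (batch 0 starts at 0) and consecutive t rows are
// gLen*hiddenSize elements apart; without packing the layout is the dense
// [B, maxFLen, maxGLen, H] tensor.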
// Helper class to calculate pointer offset that can be shared by different flavors of kernels
// For bwd, batch offset and stride are different for packing and non-packing mode.
// The reduction is done for two input tensors. Therefore, generating two sets of offsets
// according to bwdFasterDim can lead to a unified implementation in the actual kernel.
struct OffsetCalBwd{
__device__ __forceinline__ OffsetCalBwd(
int64_t batch,
const int64_t *batchOffset,
const int *fLen,
const int *gLen,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
bool bwdFasterDim) :
batch(batch),
batchOffset(batchOffset),
maxFLen(maxFLen),
maxGLen(maxGLen),
fLen(fLen),
gLen(gLen),
hiddenSize(hiddenSize),
packOutput(packOutput),
bwdFasterDim(bwdFasterDim)
{}
int64_t batch;
const int64_t *batchOffset;
const int *fLen;
const int *gLen;
int64_t maxFLen;
int64_t maxGLen;
int64_t hiddenSize;
bool packOutput;
bool bwdFasterDim; // whether doing bwd on the faster moving dimension
__device__ __forceinline__ int64_t getBatchOffset(){
return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize
: batch*maxFLen*maxGLen*hiddenSize;
}
__device__ __forceinline__ int64_t getMaxXLen(){
return bwdFasterDim ? maxGLen : maxFLen;
}
__device__ __forceinline__ auto getMyXLen() -> decltype(gLen[batch]){
return bwdFasterDim ? gLen[batch] : fLen[batch];
}
__device__ __forceinline__ auto getMyYLen() -> decltype(gLen[batch]){
return bwdFasterDim ? fLen[batch] : gLen[batch];
}
__device__ __forceinline__ int64_t getStrideX(){
return bwdFasterDim ? hiddenSize : ((packOutput ? gLen[batch] : maxGLen) * hiddenSize);
}
__device__ __forceinline__ int64_t getStrideY(){
return bwdFasterDim ? ((packOutput ? gLen[batch] : maxGLen) * hiddenSize) : hiddenSize;
}
};
// Vanilla transducer joint forward kernel
// Detail of this joint function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// f is a tensor of shape [batch, T, H]
// g is a tensor of shape [batch, U, H]
// the transducer joint does
// sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
// The resultant tensor is of shape [batch, T, U, H]
// Each thread block is working on one "batch" of data in the output tensor, [batch, t, u, :]
// This joint function can optionally pack the output, where the output tensor with a shape of
// [B, T, U, H] is packed into [B_packed, H].
// The don't-care region (t >= fLen or u >= gLen) is removed.
// To enable packing, the starting offset for each batch needs to be specified with batchOffset.
template <typename scalar_t, class OffsetCal>
__global__ void transducer_joint_forward(
const scalar_t *f,
const scalar_t *g,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
scalar_t *sum) {
const int batch = blockIdx.z;
const int t = blockIdx.y;
const int u = blockIdx.x;
const auto myFLen = fLen[batch];
const auto myGLen = gLen[batch];
OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideF = offsetCal.getStrideF();
scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize;
scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize;
scalar_t *mySum = sum + myBatchOffset + t*strideF + u * hiddenSize;
if (t < myFLen and u < myGLen){
#pragma unroll
for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){
if (h < hiddenSize){
mySum[h] = myF[h] + myG[h];
}
}
}
else if (packOutput == false and t < maxFLen and u < maxGLen){
// Need to write finite data to the don't-care region because the result tensor is
// instantiated with torch::empty for performance reasons. Even though it is a
// don't-care region, the contents need to be finite, otherwise they could lead to
// NaN in WGRAD. In packing mode this write is no longer necessary, since the
// don't-care region is removed from the output.
// Picking -1 (over 0) here for ease of testing.
#pragma unroll
for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){
if (h < hiddenSize){
mySum[h] = -1;
}
}
}
}
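// A hedged host-side reference of the unpacked joint computation described above (illustrative
// only; the CUDA kernel above is the real implementation). It assumes contiguous float buffers
// with the [B, T, H], [B, U, H] and [B, T, U, H] layouts and ignores the fLen/gLen masking.
static void transducerJointReference(const float *f, const float *g, float *out,
                                     int64_t B, int64_t T, int64_t U, int64_t H) {
    for (int64_t b = 0; b < B; ++b)
        for (int64_t t = 0; t < T; ++t)
            for (int64_t u = 0; u < U; ++u)
                for (int64_t h = 0; h < H; ++h)
                    // sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
                    out[((b*T + t)*U + u)*H + h] = f[(b*T + t)*H + h] + g[(b*U + u)*H + h];
}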
/*
Tiled version of the joint forward kernel
Detail of this joint function can be found in:
[1] Sequence Transduction with Recurrent Neural Networks.
f is a tensor of shape [batch, T, H]
g is a tensor of shape [batch, U, H]
the transducer joint does
sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
The resultant tensor is of shape [batch, T, U, H]
Each thread works on a tile of shape tileF x tileG in the result tensor.
The input for the tile is first loaded into registers and is reused tileG and tileF times.
This joint function can optionally pack the output, where the output tensor with a shape of
[B, T, U, H] is packed into [B_packed, H].
The don't-care region (t >= fLen) or (u >= gLen) is removed.
To enable packing, the starting offset for each batch needs to be specified with batchOffset.
Optionally this joint function performs ReLU and/or dropout on the joint output, which is
controlled by the arguments relu and dropout, respectively. philoxArgs is the argument used for
generating pseudorandom numbers. When at least one of ReLU and dropout is enabled, the joint
function becomes a masked operation, controlled by the template argument masked. In this case,
the masks are saved for the backward pass. A scalar sketch of this masked epilogue follows the
kernel below.
*/
template <typename scalar_t, int tileF, int tileG, int U, class OffsetCal, bool masked>
__global__ void transducer_joint_tiled_forward(
const scalar_t *f,
const scalar_t *g,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
int64_t hiddenPerBlock,
bool packOutput,
bool relu,
bool dropout,
float p,
at::PhiloxCudaState philoxArgs,
scalar_t *sum,
uint8_t *mask) {
static_assert(U == 4, "U has to be 4, as random numbers are generated in batch of 4");
const int batch = blockIdx.z;
const int t = blockIdx.y * tileF;
const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;
const int u = blockIdx.x / hiddenBlock * tileG;
const int hOffset = (blockIdx.x % hiddenBlock) * hiddenPerBlock;
const int h = threadIdx.x;
const auto myFLen = fLen[batch];
const auto myGLen = gLen[batch];
OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput);
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideF = offsetCal.getStrideF();
scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize + hOffset;
scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize + hOffset;
scalar_t *mySum = sum + myBatchOffset + t*strideF + u*hiddenSize + hOffset;
uint8_t *myMask = mask + myBatchOffset + t*strideF + u*hiddenSize + hOffset;
// The following code is only needed for dropout; we bypass it as much as possible otherwise.
auto seeds = masked ? at::cuda::philox::unpack(philoxArgs)
: std::make_tuple(static_cast<uint64_t>(0), static_cast<uint64_t>(0));
uint64_t tid = masked ? (static_cast<uint64_t>(blockIdx.z)*gridDim.y*gridDim.x +
blockIdx.y*gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x
: 0;
Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds));
scalar_t scale = masked ? ((p == 0) ? 0 : 1 / p) : 0;
bool dropoutMask[U];
if (t < myFLen and u < myGLen and hOffset+h < hiddenSize){
// register buffers for tiled input reuse
scalar_t fBuffer[tileF], gBuffer[tileG];
for (int i = 0; i < tileF; ++i){
if (t + i < myFLen)
fBuffer[i] = myF[i*hiddenSize + h];
}
for (int j = 0; j < tileG; ++j){
if (u + j < myGLen)
gBuffer[j] = myG[j*hiddenSize + h];
}
#pragma unroll
for (int i = 0; i < tileF; ++i){
if (t + i < myFLen){
#pragma unroll
for (int j = 0; j < tileG; ++j){
int idx = i*tileG + j;
if (masked and dropout and idx % U == 0){
// For performance, generate 4 random numbers in one shot
// auto rand4 = curand_uniform4(&state);
auto rand4 = uniform4(ph());
dropoutMask[0] = rand4.x < p;
dropoutMask[1] = rand4.y < p;
dropoutMask[2] = rand4.z < p;
dropoutMask[3] = rand4.w < p;
}
if (u + j < myGLen){
scalar_t out = fBuffer[i] + gBuffer[j];
if (masked){
// Apply ReLU here when relu is True
bool localMask = relu ? (out>0) : 1;
localMask = dropout ? localMask & dropoutMask[idx%U] : localMask;
out = dropout ? out*localMask*scale : out*localMask;
myMask[i*strideF + j*hiddenSize + h] = static_cast<uint8_t>(localMask);
}
mySum[i*strideF + j*hiddenSize + h] = out;
}
else if (packOutput == false and u + j < maxGLen)
mySum[i*strideF + j*hiddenSize + h] = -1;
}
}
else if (packOutput == false and t + i < maxFLen){
// Again need to write finite data to don't-care region
#pragma unroll
for (int j = 0; j < tileG; ++j){
if (u + j < maxGLen)
mySum[i*strideF + j*hiddenSize + h] = -1;
}
}
}
}
else if (packOutput == false and t < maxFLen and u < maxGLen and hOffset+h < hiddenSize){
// Only need to ensure finite values in the non-packed (normal) mode
#pragma unroll
for (int i = 0; i < tileF; ++i){
if (t + i < maxFLen){
#pragma unroll
for (int j = 0; j < tileG; ++j){
if (u + j < maxGLen)
mySum[i*strideF + j*hiddenSize + h] = -1;
}
}
}
}
}
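// A hedged scalar sketch of the masked epilogue used in the tiled kernel above (illustrative
// only): it combines an optional ReLU mask with an optional inverted-dropout mask, where
// pKeep = 1 - dropoutProb and dropoutKeep corresponds to (rand < pKeep) in the kernel.
static inline float maskedJointEpilogueReference(float sum, bool relu, bool dropout,
                                                 bool dropoutKeep, float pKeep) {
    bool keep = (!relu || sum > 0.f) && (!dropout || dropoutKeep);
    float scaled = dropout ? sum / pKeep : sum;  // inverted dropout rescales kept values by 1/pKeep
    return keep ? scaled : 0.f;
}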
/*
Bwd operation (reduction) on one input tensor. Since the operation performed for the two input
tensors is exactly the same, only one kernel is needed, and the different indexing offsets
and strides are handled by OffsetCalBwd. A host-side reference sketch of the reduction follows
the device function below.
When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a
non-packed form.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, class OffsetCal, bool masked>
__device__ void transducer_joint_single_backward(
const scalar_t *grad,
const uint8_t *mask,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
bool bwdFasterDim, // whether bwd on the faster moving dimension (u)
float scale,
scalar_t *inGrad,
int yBlockOffset=0) {
const int batch = blockIdx.z;
// For the second input tensor, this offset needs to be subtracted because the first yBlockOffset
// sets of thread blocks are for the first input tensor.
const int x = blockIdx.y-yBlockOffset;
const int hOffset = blockIdx.x*C10_WARP_SIZE;
const int wid = threadIdx.y;
const int lid = threadIdx.x;
const int numWarp = blockDim.y;
extern __shared__ char smem8[];
auto smem = reinterpret_cast<acc_t*>(smem8);
OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput,
bwdFasterDim);
const auto maxXLen = offsetCal.getMaxXLen();
const auto myXLen = offsetCal.getMyXLen();
const auto myYLen = offsetCal.getMyYLen();
scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset;
if (x < myXLen){
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideX = offsetCal.getStrideX();
const auto strideY = offsetCal.getStrideY();
const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset;
const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset : nullptr;
// Each warp reduces numYPerWarp "y" first
acc_t warpSum = 0;
auto numYPerWarp = (myYLen+numWarp-1)/numWarp;
#pragma unroll
for (int warpY = 0; warpY < numYPerWarp; ++warpY){
auto y = wid*numYPerWarp + warpY;
if (y < myYLen and (hOffset+lid) < hiddenSize)
if (masked)
warpSum += static_cast<acc_t>(myGrad[y*strideY + lid]) * myMask[y*strideY + lid] * scale;
else
warpSum += myGrad[y*strideY + lid];
}
// transpose partial sum in SMEM and reduce further using warpReduce
smem[lid*numWarp + wid] = warpSum;
__syncthreads();
auto sum = smem[wid*C10_WARP_SIZE + lid];
sum = warpReduce(sum, numWarp);
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// example of 4 warps (a, b, c, d) with 8 threads per warp
// Each warp needs 8 / 4 = 2 threads to write the results.
if (hOffset+wid*C10_WARP_SIZE/numWarp+lid/numWarp < hiddenSize){
if (lid % numWarp == 0){
myInGrad[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = sum;
}
}
}
else if (wid == 0 and hOffset + lid < hiddenSize){
// Need to ensure the grad is zero for don't care region
myInGrad[lid] = 0;
}
}
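// A hedged host-side reference of the reduction above for the f-gradient in the non-packed,
// non-masked case (illustrative only): fGrad[b, t, h] = sum over u of grad[b, t, u, h]. In the
// masked case each term would additionally be multiplied by the saved mask and by scale, and the
// g-gradient is the analogous sum over t.
static void transducerJointFGradReference(const float *grad, float *fGrad,
                                          int64_t B, int64_t T, int64_t U, int64_t H) {
    for (int64_t b = 0; b < B; ++b)
        for (int64_t t = 0; t < T; ++t)
            for (int64_t h = 0; h < H; ++h) {
                float acc = 0.f;
                for (int64_t u = 0; u < U; ++u)
                    acc += grad[((b*T + t)*U + u)*H + h];
                fGrad[(b*T + t)*H + h] = acc;
            }
}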
/*
The actual bwd (reduction) kernel that gets launched.
It calls transducer_joint_single_backward twice, once for each input tensor.
The two bwd ops are launched together: the first op uses blockIdx.y < maxFLen, and the second op
uses the rest.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, class OffsetCal, bool masked>
__global__ void transducer_joint_combined_backward(
const scalar_t *grad,
const uint8_t *mask,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
float scale,
scalar_t *fGrad,
scalar_t *gGrad) {
if (blockIdx.y < maxFLen){
transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(
grad,
mask,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
false,
scale,
fGrad);
}
else{
transducer_joint_single_backward<scalar_t, acc_t, OffsetCal, masked>(
grad,
mask,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
true,
scale,
gGrad,
maxFLen);
}
}
/*
Vectorized version of transducer_joint_single_backward
It performs the exact same operation as transducer_joint_single_backward, except that the loads
and stores are vectorized.
When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a
non-packed form.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>
__device__ void transducer_joint_single_vec_backward(
const scalar_t *grad,
const uint8_t *mask,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
bool bwdFasterDim,
float scale,
scalar_t *inGrad,
int yBlockOffset=0){
const int batch = blockIdx.z;
const int x = blockIdx.y - yBlockOffset;
const int hOffset = blockIdx.x*C10_WARP_SIZE*V;
const int wid = threadIdx.y;
const int lid = threadIdx.x;
const int numWarp = blockDim.y;
// Figure out the vectorization type for mask
using mvec_t = mvec_type<V>;
OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput,
bwdFasterDim);
const auto maxXLen = offsetCal.getMaxXLen();
const auto myXLen = offsetCal.getMyXLen();
const auto myYLen = offsetCal.getMyYLen();
scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset;
extern __shared__ char smem8[];
auto smem = reinterpret_cast<acc_t*>(smem8);
acc_t warpSum[V];
scalar_t inBuffer[V];
uint8_t maskBuffer[V];
scalar_t outBuffer[V];
auto myInGradVec = reinterpret_cast<vec_t*>(myInGrad);
auto outBufferVec = reinterpret_cast<vec_t*>(outBuffer);
if (x < myXLen){
const auto myBatchOffset = offsetCal.getBatchOffset();
const auto strideX = offsetCal.getStrideX();
const auto strideY = offsetCal.getStrideY();
const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset;
const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset
:nullptr;
for (int i = 0; i < V; ++i)
warpSum[i] = 0;
// Each warp reduces numYPerWarp "y" first
auto numYPerWarp = (myYLen+numWarp-1)/numWarp;
for (int warpY = 0; warpY < numYPerWarp; ++warpY){
auto y = wid*numYPerWarp + warpY;
auto myGradVec = reinterpret_cast<vec_t const *>(myGrad + y*strideY);
auto myMaskVec = masked ? reinterpret_cast<mvec_t const *>(myMask + y*strideY)
: nullptr;
auto inBufferVec = reinterpret_cast<vec_t*>(inBuffer);
auto maskBufferVec = reinterpret_cast<mvec_t*>(maskBuffer);
if (hOffset + lid*V < hiddenSize and y < myYLen){
*inBufferVec = myGradVec[lid]; // vectorized load
if (masked){
*maskBufferVec = myMaskVec[lid];
#pragma unroll
for (int i = 0; i < V; ++i)
warpSum[i] += static_cast<acc_t>(inBuffer[i]) * maskBuffer[i] * scale;
}
else{
#pragma unroll
for (int i = 0; i < V; ++i)
warpSum[i] += inBuffer[i];
}
}
}
// transpose partial sum in SMEM and reduce further using warpReduce
for (int i = 0; i < V; ++i){
smem[lid*numWarp + wid] = warpSum[i];
__syncthreads();
auto sum = smem[wid*C10_WARP_SIZE + lid];
if (hOffset+(wid*C10_WARP_SIZE/numWarp)*V < hiddenSize){
sum = warpReduce(sum, numWarp);
if (lid % numWarp == 0){
outBuffer[i] = sum;
}
}
__syncthreads();
}
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// a a b b c c d d
// example of 4 warps (a, b, c, d) with 8 threads per warp
// Each warp needs 8 / 4 = 2 threads to write the results.
if (lid % numWarp == 0 and hOffset+(wid*C10_WARP_SIZE/numWarp + lid/numWarp)*V < hiddenSize)
myInGradVec[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = *outBufferVec;
}
else if (wid == 0 and hOffset + lid*V < hiddenSize){
// Need to ensure the grad is zero for don't care region
myInGradVec[lid] = 0;
}
}
/*
Vectorized version of transducer_joint_combined_backward.
It calls transducer_joint_single_vec_backward twice, once for each input tensor.
The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op
uses the rest.
When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation,
and mask contains the mask information.
*/
template <typename scalar_t, typename acc_t, typename vec_t, int V, class OffsetCal, bool masked>
__global__ void transducer_joint_combined_vec_backward(
const scalar_t *grad,
const uint8_t *mask,
const int *fLen,
const int *gLen,
const int64_t *batchOffset,
int64_t maxFLen,
int64_t maxGLen,
int64_t hiddenSize,
bool packOutput,
float scale,
scalar_t *fGrad,
scalar_t *gGrad) {
if (blockIdx.y < maxFLen){
transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(
grad,
mask,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
false,
scale,
fGrad);
}
else{
transducer_joint_single_vec_backward<scalar_t, acc_t, vec_t, V, OffsetCal, masked>(
grad,
mask,
fLen,
gLen,
batchOffset,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
true,
scale,
gGrad,
maxFLen);
}
}
std::vector<torch::Tensor> transducer_joint_cuda_forward(
torch::Tensor f,
torch::Tensor g,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int64_t packedBatch,
int opt,
bool packOutput,
bool relu,
bool dropout,
float dropoutProb,
int tileSize){
auto tensorOpt = f.options();
auto dtype = f.scalar_type();
const auto batchSize = f.size(0);
const auto maxFLen = f.size(1);
const auto maxGLen = g.size(1);
const auto hiddenSize = f.size(2);
bool masked = dropout or relu;
int64_t *batchOffsetPtr = nullptr;
torch::Tensor sum, mask;
auto maskOpt = tensorOpt.dtype(torch::kUInt8);
if (!packOutput){
sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt);
batchOffsetPtr = nullptr;
if (masked)
mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt);
}
else{
sum = torch::empty({packedBatch, hiddenSize}, tensorOpt);
batchOffsetPtr = batchOffset.data_ptr<int64_t>();
if (masked)
mask = torch::empty({packedBatch, hiddenSize}, maskOpt);
}
uint8_t *maskPtr = masked ? mask.data_ptr<uint8_t>() : nullptr;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(opt == 0 or opt == 1, "Got an invalid optimization level ", opt);
// Simple heuristics
const int numThread = std::min(128, (static_cast<int>(hiddenSize)+C10_WARP_SIZE-1)
/ C10_WARP_SIZE * C10_WARP_SIZE);
if (opt == 0){
// vanilla kernel
const int threads = numThread;
const dim3 blocks(maxGLen, maxFLen, batchSize);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] {
transducer_joint_forward<scalar_t, OffsetCalFwd>
<<<blocks, threads, 0, stream>>>(
f.data_ptr<scalar_t>(),
g.data_ptr<scalar_t>(),
fLen.data_ptr<int>(),
gLen.data_ptr<int>(),
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
sum.data_ptr<scalar_t>());
}));
}
if (opt == 1){
// tiled version. For simplicity, assume tileF == tileG, even though the kernel can
// support more general cases.
const int threads = numThread;
const int hiddenPerBlock = numThread;
const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock;
const dim3 blocks( (maxGLen+tileSize-1)/tileSize * hiddenBlock,
(maxFLen+tileSize-1)/tileSize,
batchSize);
TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4,
"Expected tileSize to be in [1, 2, 4], but got ", tileSize);
at::PhiloxCudaState rng_engine_inputs;
if (masked){
// Set up the PRNG state only when the operation is masked. For non-masked calls,
// rng_engine_inputs is just a placeholder, hence there is no need to initialize it.
c10::optional<at::Generator> gen_;
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_,
at::cuda::detail::getDefaultCUDAGenerator());
// counterOffset records how many cuRAND calls each thread makes. For a tiled kernel,
// each thread processes tileF * tileG output elements.
int64_t counterOffset = tileSize * tileSize;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_cuda_state(counterOffset);
}
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] {
void(*kernel)(const scalar_t*, const scalar_t*, const int*, const int*, const int64_t*,
int64_t, int64_t, int64_t, int64_t, bool, bool, bool, float,
at::PhiloxCudaState, scalar_t*, uint8_t*);
if (masked){
switch (tileSize){
case 2:
kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd,
true>;
break;
case 4:
kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd,
true>;
break;
}
}
else{
switch (tileSize){
case 1:
kernel = &transducer_joint_tiled_forward<scalar_t, 1, 1, 4, OffsetCalFwd,
false>;
break;
case 2:
kernel = &transducer_joint_tiled_forward<scalar_t, 2, 2, 4, OffsetCalFwd,
false>;
break;
case 4:
kernel = &transducer_joint_tiled_forward<scalar_t, 4, 4, 4, OffsetCalFwd,
false>;
break;
}
}
kernel<<<blocks, threads, 0, stream>>>(
f.data_ptr<scalar_t>(),
g.data_ptr<scalar_t>(),
fLen.data_ptr<int>(),
gLen.data_ptr<int>(),
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
hiddenPerBlock,
packOutput,
relu,
dropout,
1.0f - dropoutProb,
rng_engine_inputs,
sum.data_ptr<scalar_t>(),
maskPtr);
}));
}
C10_CUDA_CHECK(cudaGetLastError());
if (masked)
return {sum, mask};
else
return {sum};
}
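// A hedged usage sketch (illustrative only, not registered with the extension): how a caller
// might invoke the forward op on small unpacked fp32 inputs with the tiled kernel. The shapes
// follow the [B, T, H] / [B, U, H] conventions above; all values here are made up.
static std::vector<torch::Tensor> transducerJointForwardExample() {
    auto opt = torch::dtype(torch::kFloat32).device(torch::kCUDA);
    auto lenOpt = torch::dtype(torch::kInt32).device(torch::kCUDA);
    auto f = torch::randn({2, 5, 16}, opt);   // [B, T, H]
    auto g = torch::randn({2, 3, 16}, opt);   // [B, U, H]
    auto fLen = torch::tensor({5, 4}, lenOpt);
    auto gLen = torch::tensor({3, 2}, lenOpt);
    // batchOffset is only read in packed mode; an empty placeholder is enough here.
    auto batchOffset = torch::empty({0}, torch::dtype(torch::kInt64).device(torch::kCUDA));
    return transducer_joint_cuda_forward(f, g, fLen, gLen, batchOffset,
                                         /*packedBatch=*/0, /*opt=*/1, /*packOutput=*/false,
                                         /*relu=*/false, /*dropout=*/false,
                                         /*dropoutProb=*/0.0f, /*tileSize=*/4);
}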
std::vector<torch::Tensor> transducer_joint_cuda_backward(
std::vector<torch::Tensor> in,
torch::Tensor fLen,
torch::Tensor gLen,
torch::Tensor batchOffset,
int maxFLen,
int maxGLen,
bool packOutput,
float scale){
auto grad = in[0];
bool masked = (in.size() == 2);
uint8_t *maskPtr = masked ? in[1].data_ptr<uint8_t>() : nullptr;
auto tensorOpt = grad.options();
auto dtype = grad.scalar_type();
const int batchSize = fLen.size(0);
const int hiddenSize = grad.size(-1);
const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE;
torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt);
torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt);
int64_t *batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr<int64_t>();
// The number "y" I would like each thread to work on
const int workPerThread = 32;
// Since the bwd ops for f and g use the same thread block size, we need to use the max of the two.
int numWarp = largestPowerOfTwo((std::max(maxFLen, maxGLen) + workPerThread-1) / workPerThread);
// Would like to have at least 2 warps
numWarp = std::max(2, numWarp);
// cap on the maximum number of warps allowed
numWarp = std::min(maxNumWarp, numWarp);
// Need smem for transposing the partial sum. The partial sum is in a matrix of the shape
// numWarp x warpSize
const int smemSize = numWarp * C10_WARP_SIZE;
const dim3 threads(C10_WARP_SIZE, numWarp, 1);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_cuda_backward_kernel", ([&] {
auto gradPtr = grad.data_ptr<scalar_t>();
auto fLenPtr = fLen.data_ptr<int>();
auto gLenPtr = gLen.data_ptr<int>();
auto fGradPtr = fGrad.data_ptr<scalar_t>();
auto gGradPtr = gGrad.data_ptr<scalar_t>();
// resolve the acc_t type
using acc_t = at::acc_type<scalar_t, true>;
using vec_t = uint64_t;
constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);
constexpr int vecAlignment = std::alignment_of<vec_t>::value;
// if all input and output tensors meet the alignment requirement
bool memAlign = (reinterpret_cast<uint64_t>(gradPtr) % vecAlignment == 0)
and (reinterpret_cast<uint64_t>(fGradPtr) % vecAlignment == 0)
and (reinterpret_cast<uint64_t>(gGradPtr) % vecAlignment == 0);
if (vectFactor > 1 and hiddenSize%vectFactor == 0 and memAlign){
// If vectorization helps and the alignment requirement is met, use the vectorized
// kernel. For simplicity, hiddenSize needs to be a multiple of vectFactor.
const dim3 blocks( (hiddenSize+C10_WARP_SIZE*vectFactor-1)/(C10_WARP_SIZE*vectFactor),
maxFLen+maxGLen,
batchSize);
if (masked){
transducer_joint_combined_vec_backward
<scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, true>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
maskPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
scale,
fGradPtr,
gGradPtr);
}
else{
transducer_joint_combined_vec_backward
<scalar_t, acc_t, vec_t, vectFactor, OffsetCalBwd, false>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
maskPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
scale,
fGradPtr,
gGradPtr);
}
}
else{
const dim3 blocks((hiddenSize+C10_WARP_SIZE-1)/C10_WARP_SIZE,
maxFLen + maxGLen, batchSize);
if (masked){
transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, true>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
maskPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
scale,
fGradPtr,
gGradPtr);
}
else{
transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, false>
<<<blocks, threads, smemSize*sizeof(acc_t)>>>(
gradPtr,
maskPtr,
fLenPtr,
gLenPtr,
batchOffsetPtr,
maxFLen,
maxGLen,
hiddenSize,
packOutput,
scale,
fGradPtr,
gGradPtr);
}
}
}));
return {fGrad, gGrad};
}
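// Note on the vectorization choice above: with vec_t = uint64_t, vectFactor is 4 for fp16
// (8 B / 2 B) and 2 for fp32 (8 B / 4 B), so each lane moves 8 bytes per load/store. The
// vectorized path is therefore only taken when hiddenSize is a multiple of vectFactor and all
// three data pointers (grad, fGrad, gGrad) are 8-byte aligned; otherwise the scalar kernel is used.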
#include <torch/extension.h>
#include <vector>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> transducer_loss_cuda_forward(
torch::Tensor x,
torch::Tensor label,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool packedInput);
torch::Tensor transducer_loss_cuda_backward(
torch::Tensor x,
torch::Tensor lossGrad,
torch::Tensor alpha,
torch::Tensor beta,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor label,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool fuseSoftmaxBackward,
bool packedInput);
std::vector<torch::Tensor> transducer_loss_forward(
torch::Tensor x,
torch::Tensor label,
torch::Tensor fLen,
torch::Tensor yLen,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool packedInput
) {
CHECK_INPUT(x);
CHECK_INPUT(label);
CHECK_INPUT(fLen);
CHECK_INPUT(yLen);
if (packedInput)
CHECK_INPUT(batchOffset);
return transducer_loss_cuda_forward(
x,
label,
fLen,
yLen,
batchOffset,
maxFLen,
blankIdx,
opt,
packedInput);
}
torch::Tensor transducer_loss_backward(
torch::Tensor x,
torch::Tensor lossGrad,
torch::Tensor alpha,
torch::Tensor beta,
torch::Tensor fLen,
torch::Tensor yLen,
torch::Tensor label,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool fuseSoftmaxBackward,
bool packedInput){
CHECK_INPUT(x);
CHECK_INPUT(label);
CHECK_INPUT(lossGrad);
CHECK_INPUT(alpha);
CHECK_INPUT(beta);
CHECK_INPUT(fLen);
CHECK_INPUT(yLen);
if (packedInput)
CHECK_INPUT(batchOffset);
return transducer_loss_cuda_backward(
x,
lossGrad,
alpha,
beta,
fLen,
yLen,
label,
batchOffset,
maxFLen,
blankIdx,
opt,
fuseSoftmaxBackward,
packedInput);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &transducer_loss_forward, "transducer loss forward (CUDA)");
m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)");
}
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
template<typename scalar_t>
__device__ __forceinline__ scalar_t logSumExp(scalar_t a, scalar_t b) {
// The standard log-sum-exp trick is used here to provide better numerical stability
return (a >= b) ? a + std::log1p(exp(b-a)) : b + std::log1p(exp(a-b));
}
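// Worked example of why the shifted form above matters: for a = b = 1000, a naive
// log(exp(a) + exp(b)) overflows to inf, while the expression above evaluates
// 1000 + log1p(exp(0)) = 1000 + log(2) without ever forming a large intermediate.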
// Vanilla transducer loss function (i.e. forward-backward algorithm)
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// Forward (alpha) and backward (beta) paths are launched together. The input is assumed to have
// been converted to log scale by the preceding log_softmax layer.
// Diagonal wavefront advancing, commonly used in dynamic programming, is leveraged here.
// alpha and beta are of acc_t type, as they are essentially accumulators.
// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into
// [B_packed, H].
// The don't-care region (t >= audLen) or (u > txtLen) is removed.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_forward(
const scalar_t* x,
const int* label,
const int* audLen,
const int* txtLen,
const int64_t* batchOffset,
int64_t dictSize, // 64-bit indexing for data tensor
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
acc_t* alpha,
acc_t* beta,
scalar_t* loss) {
const int batch = blockIdx.y;
const int tid = threadIdx.x;
const auto myFLen = audLen[batch];
// Note that the start-of-sentence symbol accounts for the +1 here
const auto myGLen = txtLen[batch] + 1;
const auto myLabel = label + batch * (maxGLen-1);
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
const scalar_t* myX = x + myBatchOffset * dictSize;
int u = tid;
if (blockIdx.x == 0){
// alpha path
acc_t* myAlpha = alpha + batch*maxFLen*maxGLen;
if (u == 0)
myAlpha[0] = 0;
__syncthreads();
for (int64_t step = 1; step < myFLen+myGLen-1; ++step){
// Move along the diagonal wavefront to leverage available parallelism
for (u = tid; u < myGLen; u += blockDim.x){
int64_t t = step - u;
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
// Eq(16) in [1]
if (u == 0){
// alpha(t, u) = alpha(t-1, u) * null(t-1, u)
myAlpha[t*maxGLen + u] = myAlpha[(t-1)*maxGLen]
+ myX[((t-1)*myStrideT) * dictSize + blankIdx];
}
else if (t == 0){
// alpha(t, u) = alpha(t, u-1) * y(t, u-1)
myAlpha[u] = myAlpha[u - 1] + myX[(u - 1) * dictSize + myLabel[u - 1]];
}
else{
// alpha(t, u) = alpha(t-1, u) * null(t-1, u) + alpha(t, u-1) * y(t, u-1)
acc_t current = myAlpha[(t-1)*maxGLen + u]
+ myX[((t-1)*myStrideT + u) * dictSize + blankIdx];
acc_t next = myAlpha[t*maxGLen + u - 1]
+ myX[(t*myStrideT + u - 1) * dictSize + myLabel[u - 1]];
myAlpha[t*maxGLen + u] = logSumExp(next, current);
}
}
}
__syncthreads();
}
}
else if (blockIdx.x == 1){
// beta path
acc_t* myBeta = beta + batch*maxFLen*maxGLen;
if (u == 0){
myBeta[(myFLen-1)*maxGLen + myGLen - 1] = myX[((myFLen-1)*myStrideT
+ myGLen - 1) * dictSize + blankIdx];
}
__syncthreads();
for (int64_t step = myFLen+myGLen - 3; step >= 0; --step){
for (u = tid; u < myGLen; u += blockDim.x){
int64_t t = step - u;
if (t >= 0 and t < myFLen and u >=0 and u < myGLen){
// Eq(18) in [1]
if (u == myGLen - 1){
// beta(t, u) = beta(t+1, u) * null(t, u)
myBeta[t*maxGLen + u] = myBeta[(t+1)*maxGLen + u]
+ myX[(t*myStrideT + u) * dictSize + blankIdx];
}
else if (t == myFLen - 1){
// beta(t, u) = beta(t, u+1) * y(t, u)
myBeta[t*maxGLen + u] = myBeta[t*maxGLen + u + 1]
+ myX[(t*myStrideT + u) * dictSize + myLabel[u]];
}
else{
// beta(t, u) = beta(t+1, u)*null(t, u) + beta(t, u+1)*y(t, u)
acc_t current = myBeta[(t+1)*maxGLen + u]
+ myX[(t*myStrideT + u) * dictSize + blankIdx];
acc_t next = myBeta[t*maxGLen + u + 1]
+ myX[(t*myStrideT + u) * dictSize + myLabel[u]];
myBeta[t*maxGLen + u] = logSumExp(next, current);
}
}
}
__syncthreads();
}
if (tid == 0)
loss[batch] = -myBeta[0];
}
}
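// Summary of the log-space recursions implemented above (Eq(16) and Eq(18) in [1]):
//   alpha(t, u) = logSumExp( alpha(t-1, u) + x(t-1, u, blank),  alpha(t, u-1) + x(t, u-1, y_u) )
//   beta(t, u)  = logSumExp( beta(t+1, u)  + x(t, u, blank),    beta(t, u+1)  + x(t, u, y_{u+1}) )
// with alpha(0, 0) = 0, beta(T-1, U-1) = x(T-1, U-1, blank) and loss = -beta(0, 0) = -log Pr(y*|x),
// where y_u denotes label[u-1] (the u-th target token), T = audLen and U = txtLen + 1.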
// Transducer loss function (i.e. forward-backward algorithm) with batch loading optimization.
// Compared to the vanilla version, there are two optimizations:
// 1. Load x in batches through loop unrolling to reduce latency.
// 2. Use registers and shared memory to hold alpha and beta values passed from one step to the next.
// For simplicity, this kernel currently only supports U <= maxThread, which should be the common
// case. For cases where U > maxThread, the vanilla kernel is used as a fallback option.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// Forward (alpha) and backward (beta) paths are launched together. The input is assumed to have
// been converted to log scale by the preceding log_softmax layer.
// Diagonal wavefront advancing, commonly used in dynamic programming, is leveraged here.
// alpha and beta are of acc_t type, as they are essentially accumulators.
// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into
// [B_packed, H].
// The don't-care region (t >= audLen) or (u > txtLen) is removed.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t, int batchLdSize>
__global__ void transducer_loss_batch_load_forward(
const scalar_t* x,
const int* label,
const int* audLen,
const int* txtLen,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
acc_t* alpha,
acc_t* beta,
scalar_t* loss) {
const int batch = blockIdx.y;
int u = threadIdx.x;
const auto myFLen = audLen[batch];
const auto myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
const scalar_t* myX = x + myBatchOffset * dictSize;
scalar_t next[batchLdSize], current[batchLdSize];
extern __shared__ char smem8[];
auto smem = reinterpret_cast<acc_t*>(smem8);
if (blockIdx.x == 0){
// alpha path
acc_t* myAlpha = alpha + batch*maxFLen*maxGLen;
// two SMEM regions to double buffer reads and writes and avoid data races
acc_t * const sharedAlpha[2] = {smem, smem+maxGLen};
sharedAlpha[0][u] = 0;
__syncthreads();
if (u == 0)
myAlpha[0] = 0;
auto myAlphaLabel = (u == 0) ? 0 : label[batch*(maxGLen-1) + u - 1];
// register used to pass value to the next step for the same thread
acc_t prvStepAlpha = 0;
for (int64_t step = 1; step < myFLen+myGLen-1+batchLdSize; step += batchLdSize){
// Move along the diagonal wavefront to leverage available parallelism
// Batch loading X through loop unrolling
#pragma unroll
for (int i = 0; i < batchLdSize; ++i){
if (step+i<myFLen+myGLen-1){
// index computing
int64_t t = step + i - u;
int64_t currentId = ((t-1)*myStrideT + u) * dictSize + blankIdx;
int64_t nextId = (t*myStrideT + u - 1) * dictSize + myAlphaLabel;
// main loading loop
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
if (u == 0){
current[i] = myX[currentId];
}
else if (t == 0){
next[i] = myX[nextId];
}
else{
current[i] = myX[currentId];
next[i] = myX[nextId];
}
}
}
}
// main computing loop
for (int i = 0; i < batchLdSize; ++i){
// swap the pointer for double buffering
auto sharedAlphaRd = sharedAlpha[(step+i-1)%2];
auto sharedAlphaWr = sharedAlpha[(step+i)%2];
if (step+i<myFLen+myGLen-1){
int64_t t = step + i - u;
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
// Eq(16) in [1]
if (u == 0)
prvStepAlpha = prvStepAlpha+current[i];
else if (t == 0)
prvStepAlpha = sharedAlphaRd[u-1]+next[i];
else
prvStepAlpha = logSumExp(prvStepAlpha+current[i], sharedAlphaRd[u-1]
+ next[i]);
sharedAlphaWr[u] = prvStepAlpha;
myAlpha[t*maxGLen + u] = prvStepAlpha;
}
}
__syncthreads();
}
}
}
else if (blockIdx.x == 1){
// beta path
acc_t* myBeta = beta + batch*maxFLen*maxGLen;
// two SMEM regions to double buffer reads and writes and avoid data races
acc_t * const sharedBeta[2] = {smem, smem + maxGLen};
sharedBeta[0][u] = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx];
__syncthreads();
auto myBetaLabel = (u == maxGLen - 1) ? 0 : label[batch*(maxGLen-1) + u];
// register used to pass value to the next step for the same thread
acc_t prvStepBeta = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx];
if (u == 0)
myBeta[(myFLen-1)*maxGLen + myGLen - 1] = prvStepBeta;
for (int64_t step = 1; step < myFLen+myGLen-1; step += batchLdSize){
// Move along the diagonal wavefront to leverage available parallelism
// Batch loading X
#pragma unroll
for (int i = 0; i < batchLdSize; ++i){
if (step+i<myFLen+myGLen-1){
// index computing
int64_t t = myFLen+myGLen - (step + i) - 2 - u;
int64_t currentId = (t*myStrideT + u) * dictSize + blankIdx;
int64_t nextId = (t*myStrideT + u) * dictSize + myBetaLabel;
// main loading loop
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
if (u == myGLen - 1){
current[i] = myX[currentId];
}
else if (t == myFLen - 1){
next[i] = myX[nextId];
}
else{
current[i] = myX[currentId];
next[i] = myX[nextId];
}
}
}
}
// main computing loop
for (int i = 0; i < batchLdSize; ++i){
// swap the pointer for double buffering
auto sharedBetaRd = sharedBeta[(step+i-1)%2];
auto sharedBetaWr = sharedBeta[(step+i)%2];
if (step+i<myFLen+myGLen-1){
int64_t t = myFLen+myGLen - (step + i) - 2 - u;
if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){
// Eq(18) in [1]
if (u == myGLen - 1)
prvStepBeta = prvStepBeta+current[i];
else if (t == myFLen - 1)
prvStepBeta = sharedBetaRd[u+1]+next[i];
else
prvStepBeta = logSumExp(prvStepBeta+current[i], sharedBetaRd[u+1]
+ next[i]);
sharedBetaWr[u] = prvStepBeta;
myBeta[t*maxGLen + u] = prvStepBeta;
}
}
__syncthreads();
}
}
if (u == 0)
loss[batch] = -prvStepBeta;
}
}
// Vanilla transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// For this backward kernel, bwd op for the preceding softmax is assumed to be handled elsewhere,
// hence only Eq(20) in [1] is implemented in this kernel.
// Each thread block works on [batch, t, :, :] of data. Each thread works on a specific u at a time
// Since only gradients for the correct token and null token need to be updated, gradients at other
// locations are initialized to 0.
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_backward(
const scalar_t* x,
const scalar_t* lossGrad,
const int* audLen,
const int* txtLen,
const int* label,
const acc_t* alpha,
const acc_t* beta,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
scalar_t* xGrad) {
const int tid = threadIdx.x;
const int t = blockIdx.x;
const int batch = blockIdx.y;
const int64_t myFLen = audLen[batch];
const int64_t myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
auto myX = x + (myBatchOffset + t*myStrideT)*dictSize;
auto myAlpha = alpha + batch*maxFLen*maxGLen;
auto myBeta = beta + batch*maxFLen*maxGLen;
auto myXGrad = xGrad + (myBatchOffset + t*myStrideT)*dictSize;
auto myLabel = label + batch*(maxGLen-1);
int64_t u = tid;
while (t < myFLen and u < myGLen){
// Do the update
// loss = -ln(Pr(y*|x))
acc_t grad = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
if (u != myGLen - 1)
myXGrad[u*dictSize + myLabel[u]] = -std::exp(grad + myBeta[t*maxGLen + u + 1]
+ myX[u*dictSize + myLabel[u]]);
if (t == myFLen - 1 and u == myGLen - 1)
myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myX[u*dictSize + blankIdx]);
else if (t != myFLen - 1)
myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myBeta[(t+1)*maxGLen + u]
+ myX[u*dictSize + blankIdx]);
u += blockDim.x;
}
}
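// In log space, the updates above implement Eq(20) in [1]:
//   d loss / d x(t, u, k) = -lossGrad[batch] * exp( alpha(t, u) + x(t, u, k) + betaNext - beta(0, 0) )
// where betaNext = beta(t, u+1) when k == label[u], betaNext = beta(t+1, u) when k == blank
// (betaNext = 0 for the final blank at (T-1, U-1)), and all other entries stay zero because the
// host code pre-initializes xGrad with zeros for this non-fused path.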
// Fused transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The bwd op of the preceding softmax layer is fused in this kernel.
// Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t>
__global__ void transducer_loss_fused_backward(
const scalar_t* x,
const scalar_t* lossGrad,
const int* audLen,
const int* txtLen,
const int* label,
const acc_t* alpha,
const acc_t* beta,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
scalar_t* xGrad) {
const int tid = threadIdx.x;
const int u = blockIdx.x;
const int t = blockIdx.y;
const int batch = blockIdx.z;
const int64_t myFLen = audLen[batch];
const int64_t myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
__shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize;
if (t < myFLen and u < myGLen){
auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize;
auto myAlpha = alpha + batch*maxFLen*maxGLen;
auto myBeta = beta + batch*maxFLen*maxGLen;
auto myLabel = label + batch*(maxGLen-1);
// load and store shared variables in SMEM
if (tid == 0){
commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
myBetaTU = myBeta[t*maxGLen + u];
myBetaTUp1 = myBeta[t*maxGLen + u + 1];
myBetaTp1U = myBeta[(t+1)*maxGLen + u];
myLabelShared = myLabel[u];
}
__syncthreads();
for (int64_t h = tid; h < dictSize; h += blockDim.x){
// Do the update
acc_t grad = commonFactor + myX[h]; // loss = -ln(Pr(y*|x))
acc_t myGrad = std::exp(grad + myBetaTU);
if (u != myGLen - 1 and h == myLabelShared){
myGrad -= std::exp(grad + myBetaTUp1);
}
else if (h == blankIdx){
if (t == myFLen - 1 and u == myGLen - 1)
myGrad -= std::exp(grad);
else if (t != myFLen - 1)
myGrad -= std::exp(grad + myBetaTp1U);
}
myXGrad[h] = myGrad;
}
}
else if (!packedInput){
// In non-pack mode, need to make sure the gradients for don't-care regions are zero.
for (int64_t h = tid; h < dictSize; h += blockDim.x){
myXGrad[h] = 0;
}
}
}
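// The fused form above folds the log_softmax backward into Eq(20): summing the non-fused
// gradients over the dictionary and using the beta recursion gives
//   sum_k d loss / d x(t, u, k) = -lossGrad[batch] * exp( alpha(t, u) + beta(t, u) - beta(0, 0) ),
// so the softmax-Jacobian term -softmax(h) * sum_k (...) reduces to exp(grad + myBetaTU) with
// grad = log(lossGrad) + alpha(t, u) + x(t, u, h) - beta(0, 0), while the subtracted exp(...)
// terms are the same label/blank contributions computed by transducer_loss_backward above.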
// Vectorized version of the fused transducer loss backward operation.
// Detail of this loss function can be found in:
// [1] Sequence Transduction with Recurrent Neural Networks.
// The bwd op of the preceding softmax layer is fused in this kernel.
// Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time
// To support the packed input, the starting offsets for each batch need to be specified with
// batchOffset.
template <typename scalar_t, typename acc_t, typename vec_t, int V>
__global__ void transducer_loss_fused_vec_backward(
const scalar_t* x,
const scalar_t* lossGrad,
const int* audLen,
const int* txtLen,
const int* label,
const acc_t* alpha,
const acc_t* beta,
const int64_t* batchOffset,
int64_t dictSize,
int64_t blankIdx,
int64_t maxFLen,
int64_t maxGLen,
bool packedInput,
scalar_t* xGrad) {
const int tid = threadIdx.x;
const int u = blockIdx.x;
const int t = blockIdx.y;
const int batch = blockIdx.z;
const int64_t myFLen = audLen[batch];
const int64_t myGLen = txtLen[batch] + 1;
const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1])
: batch * maxFLen * maxGLen;
const int64_t myStrideT = packedInput ? myGLen : maxGLen;
__shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared;
auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize;
auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize;
auto myAlpha = alpha + batch*maxFLen*maxGLen;
auto myBeta = beta + batch*maxFLen*maxGLen;
auto myLabel = label + batch*(maxGLen-1);
// Variables for vectorization
scalar_t myXBuffer[V], myXGradBuffer[V];
auto myXVec = reinterpret_cast<vec_t const *>(myX);
auto myXGradVec = reinterpret_cast<vec_t*>(myXGrad);
auto myXBufferVec = reinterpret_cast<vec_t*>(myXBuffer);
auto myXGradBufferVec = reinterpret_cast<vec_t*>(myXGradBuffer);
if (t < myFLen and u < myGLen){
// load and store shared variables in SMEM
if (tid == 0){
commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0];
myBetaTU = myBeta[t*maxGLen + u];
if (t != myFLen - 1)
myBetaTp1U = myBeta[(t+1)*maxGLen + u];
if (u != myGLen - 1){
myBetaTUp1 = myBeta[t*maxGLen + u + 1];
myLabelShared = myLabel[u];
}
}
__syncthreads();
#pragma unroll
for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){
// Load myX in a vector form
*myXBufferVec = myXVec[h0/V];
// Do the update for a vector of input
#pragma unroll
for (int i = 0; i < V; ++i){
auto h = h0 + i;
acc_t grad = commonFactor + myXBuffer[i]; // loss = -ln(Pr(y*|x))
acc_t myGrad = std::exp(grad + myBetaTU);
if (u != myGLen - 1 and h == myLabelShared){
myGrad -= std::exp(grad + myBetaTUp1);
}
else if (h == blankIdx){
if (t == myFLen - 1 and u == myGLen - 1)
myGrad -= std::exp(grad);
else if (t != myFLen - 1)
myGrad -= std::exp(grad + myBetaTp1U);
}
myXGradBuffer[i] = myGrad;
}
// Store myXGrad in a vector form
myXGradVec[h0/V] = *myXGradBufferVec;
}
}
else if (!packedInput){
// In non-pack mode, need to make sure the gradients for don't-care regions are zero.
for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){
myXGradVec[h0/V] = 0;
}
}
}
std::vector<torch::Tensor> transducer_loss_cuda_forward(
torch::Tensor x,
torch::Tensor label,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool packedInput){
auto scalarType = x.scalar_type();
auto tensorOpt = x.options();
const int batchSize = label.size(0);
const int maxGLen = label.size(1) + 1;
const int dictSize = x.size(-1);
TORCH_CHECK(blankIdx >= 0 and blankIdx < dictSize,
"Expected blank index to be in the range of 0 to ",
dictSize-1,
", but got ",
blankIdx);
TORCH_CHECK(opt == -1 or opt == 0 or opt == 1,
"Got an invalid optimization level ",
opt);
// The data type of alpha and beta will be resolved at dispatch time,
// hence defined here and assigned later
torch::Tensor alpha;
torch::Tensor beta;
torch::Tensor loss = torch::empty({batchSize}, tensorOpt);
const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
const auto maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;
const auto maxSmemPerBlock = deviceProperties->sharedMemPerBlock;
const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr<int64_t>() : nullptr;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(scalarType, "transducer_loss_cuda_forward", ([&] {
// resolve accumulation type
using acc_t = at::acc_type<scalar_t, true>;
auto accType = c10::CppTypeToScalarType<acc_t>::value;
auto accTensorOpt = tensorOpt.dtype(accType);
alpha = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);
beta = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt);
// Decide which kernel to launch based on the problem size.
// If the required SMEM size or the number of threads exceeds the limit, fall back to the vanilla
// kernel.
const auto smemSize = 2*maxGLen*sizeof(acc_t);
const auto optFallBack = (maxGLen > maxThreadPerBlock or smemSize > maxSmemPerBlock) ? 0
: (opt == -1) ? 1 : opt;
const int threads = std::min(maxThreadPerBlock, maxGLen);
const dim3 blocks(2, batchSize, 1);
if (optFallBack == 0)
transducer_loss_forward<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
label.data_ptr<int>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
loss.data_ptr<scalar_t>());
else if (optFallBack == 1)
transducer_loss_batch_load_forward<scalar_t, acc_t, 4>
<<<blocks, threads, smemSize, stream>>>(
x.data_ptr<scalar_t>(),
label.data_ptr<int>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
loss.data_ptr<scalar_t>());
}));
C10_CUDA_CHECK(cudaGetLastError());
return {alpha, beta, loss};
}
torch::Tensor transducer_loss_cuda_backward(
torch::Tensor x,
torch::Tensor lossGrad,
torch::Tensor alpha,
torch::Tensor beta,
torch::Tensor audLen,
torch::Tensor txtLen,
torch::Tensor label,
torch::Tensor batchOffset,
int maxFLen,
int blankIdx,
int opt,
bool fuseSoftmaxBackward,
bool packedInput){
auto dtype = x.scalar_type();
torch::Tensor xGrad;
const int batchSize = label.size(0);
const int maxGLen = label.size(1) + 1;
const int dictSize = x.size(-1);
const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
const int maxThreadPerBlock = deviceProperties->maxThreadsPerBlock;
const int warpSize = deviceProperties->warpSize;
const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr<int64_t>() : nullptr;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (fuseSoftmaxBackward){
// Allocate empty tensors for performance; hence zeros need to be written to the
// don't-care region in the kernel.
xGrad = torch::empty_like(x);
// Would like each thread to work on 4 hidden units
const int workPerThread = 4;
// Don't want to have more than 128 threads per thread block
const int maxThreadPerElmt = std::min(128, maxThreadPerBlock);
const int threads = std::min(maxThreadPerElmt, std::max(warpSize,
(dictSize+workPerThread-1)/workPerThread));
const dim3 blocks(maxGLen, maxFLen, batchSize);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] {
using vec_t = uint64_t;
using acc_t = at::acc_type<scalar_t, true>;
constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t);
constexpr int vecAlignment = std::alignment_of<vec_t>::value;
// if all input and output tensors meet the alignment requirement
bool memAlign = reinterpret_cast<uint64_t>(x.data_ptr<scalar_t>()) % vecAlignment == 0
and reinterpret_cast<uint64_t>(xGrad.data_ptr<scalar_t>())
% vecAlignment == 0;
if (vectFactor > 1 and dictSize%vectFactor == 0 and memAlign){
transducer_loss_fused_vec_backward<scalar_t, acc_t, vec_t, vectFactor>
<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
lossGrad.data_ptr<scalar_t>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
}
else{
transducer_loss_fused_backward<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
lossGrad.data_ptr<scalar_t>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
}
}));
}
else{
// For the non-fused kernel, the gradients that need to be written are very sparse, hence
// initialize the tensor with all zeros.
xGrad = torch::zeros_like(x);
// don't launch more threads than needed.
const int threads = std::min(maxThreadPerBlock, maxGLen);
const dim3 blocks(maxFLen, batchSize);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] {
using acc_t = at::acc_type<scalar_t, true>;
transducer_loss_backward<<<blocks, threads, 0, stream>>>(
x.data_ptr<scalar_t>(),
lossGrad.data_ptr<scalar_t>(),
audLen.data_ptr<int>(),
txtLen.data_ptr<int>(),
label.data_ptr<int>(),
alpha.data_ptr<acc_t>(),
beta.data_ptr<acc_t>(),
batchOffsetPtr,
dictSize,
blankIdx,
maxFLen,
maxGLen,
packedInput,
xGrad.data_ptr<scalar_t>());
}));
}
C10_CUDA_CHECK(cudaGetLastError());
return xGrad;
}
#include <torch/extension.h>
// CUDA forward declarations
std::vector<at::Tensor> softmax_xentropy_cuda(
const at::Tensor &input,
const at::Tensor &labels,
const float smoothing,
const bool half_to_float);
at::Tensor softmax_xentropy_backward_cuda(
const at::Tensor &grad_loss,
const at::Tensor &logits,
const at::Tensor &max_log_sum_exp,
const at::Tensor &labels,
const float smoothing);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> softmax_xentropy_forward(
const at::Tensor &input,
const at::Tensor &labels,
const float smoothing,
const bool half_to_float) {
CHECK_CUDA(input);
CHECK_INPUT(labels);
return softmax_xentropy_cuda(input, labels, smoothing, half_to_float);
}
at::Tensor softmax_xentropy_backward(
const at::Tensor &grad_loss,
const at::Tensor &logits,
const at::Tensor &max_log_sum_exp,
const at::Tensor &labels,
const float smoothing) {
CHECK_CUDA(grad_loss);
CHECK_CUDA(logits);
CHECK_INPUT(max_log_sum_exp);
CHECK_INPUT(labels);
return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)");
m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)");
}