Commit 8a915496 authored by Guolin Ke

first commit

parent 5cf7df97
The MIT License (MIT)
Copyright (c) 2022 DP Technology
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Uni-Core, an efficient distributed PyTorch framework
====================================================
Uni-Core is built for rapidly creating high-performance PyTorch models, especially Transformer-based ones. It supports the following features:
- Distributed training over multi-GPUs and multi-nodes
- Mixed-precision training with fp16 and bf16
- High-performance fused CUDA kernels
- Model checkpoint management
- Friendly logging
- Buffered (GPU-CPU overlapping) data loader
- Gradient accumulation
- Commonly used optimizers and LR schedulers
- Easy to create new models
To install:
```bash
python setup.py install
```
We recommend using [docker](https://github.com/dptech-corp/Uni-Core/blob/main/docker/Dockerfile) for installation.
To build a model, you can refer to [examples/bert](https://github.com/dptech-corp/Uni-Core/tree/main/examples/bert).
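As a quick smoke test after installation, you can call one of the fused CUDA ops directly from Python. The sketch below is illustrative only: the `adam(...)` signature matches the pybind binding in this commit, but the module name `unicore_fused_adam` is an assumption, as the extension names are registered in `setup.py`.
```python
# Minimal sketch of calling the fused Adam binding from this commit.
import torch
import unicore_fused_adam  # hypothetical extension name; see setup.py

p = torch.randn(1024, device="cuda")  # parameters, updated in place
g = torch.randn_like(p)               # gradients
m = torch.zeros_like(p)               # Adam first moment
v = torch.zeros_like(p)               # Adam second moment

# adam(p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, bias_correction, decay)
unicore_fused_adam.adam(p, m, v, g, 1e-3, 0.9, 0.999, 1e-8, 1.0, 1, 1, 0.01)
```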
Related projects
----------------
- [Uni-Mol](https://github.com/dptech-corp/Uni-Mol)
Acknowledgement
---------------
The main framework is from [facebookresearch/fairseq](https://github.com/facebookresearch/fairseq).
The fused kernels are from [guolinke/fused_ops](https://github.com/guolinke/fused_ops).
Dockerfile is from [guolinke/pytorch-docker](https://github.com/guolinke/pytorch-docker).
License
-------
This project is licensed under the terms of the MIT license. See [LICENSE](https://github.com/dptech-corp/Uni-Core/blob/main/LICENSE) for additional details.
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/TensorUtils.h"
#include "ATen/AccumulateType.h"
#include <ATen/cuda/Exceptions.h>
#include "type_shim.h"
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
GRAD_T* __restrict__ p,
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
const float decay_size)
{
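// Grid-stride loop over the flattened parameter tensor: thread i handles
// elements i, i + totThreads, i + 2*totThreads, ...
// The update is AdamW-style: weight decay is applied by scaling p
// (decay_size = 1 - step_size*decay, computed on the host) and the bias
// correction is folded into step_size, so the kernel only needs
// p <- p*decay_size - step_size * m/(sqrt(v) + eps).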
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
for (int j = i; j < tsize; j+=totThreads) {
// weight decay
T cur_p = (T)p[j] * decay_size;
T scaled_grad = static_cast<T>(g[j]) / grad_scale;
m[j] = b1*m[j] + (1-b1)*scaled_grad;
v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;
const float update = m[j] / (sqrtf(v[j]) + eps);
p[j] = cur_p - (step_size*update);
}
}
void fused_adam_cuda(
at::Tensor & p,
at::Tensor & m,
at::Tensor & v,
at::Tensor & g,
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int bias_correction,
float decay)
{
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = lr;
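// Fold the Adam bias correction into the step size:
//   step_size = lr * sqrt(1 - beta2^step) / (1 - beta1^step)
// so the kernel can apply m/(sqrt(v) + eps) without per-element correction.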
if (bias_correction == 1) {
const double bias_correction1 = 1.0 - std::pow(static_cast<double>(beta1), step);
const double bias_correction2 = 1.0 - std::pow(static_cast<double>(beta2), step);
step_size = static_cast<float>(lr * std::sqrt(bias_correction2) / bias_correction1);
}
float decay_size = 1.0;
if (decay != 0.0) {
decay_size = 1.0 - step_size * decay;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) {
AT_ASSERTM(p.scalar_type() == g.scalar_type(), "expected parameter to be the same type as grad");
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF_AND_BF16(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.data_ptr<scalar_t_0>(),
m.data_ptr<accscalar_t>(),
v.data_ptr<accscalar_t>(),
g.data_ptr<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
decay_size);
);
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.data_ptr<scalar_t_0>(),
m.data_ptr<scalar_t_0>(),
v.data_ptr<scalar_t_0>(),
g.data_ptr<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
decay_size);
);
}
AT_CUDA_CHECK(cudaGetLastError());
}
#include <torch/extension.h>
void fused_adam_cuda(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int bias_correction, float decay);
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
void adam(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int bias_correction, float decay) {
CHECK_INPUT(p);
CHECK_INPUT(m);
CHECK_INPUT(v);
CHECK_INPUT(g);
int64_t num_elem = p.numel();
AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
fused_adam_cuda(p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, bias_correction, decay);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("adam", &adam, "Adam optimized CUDA implementation.");
}
#include <torch/extension.h>
#include <vector>
#include <cassert>
namespace {
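// Splits the input into n1 rows of length n2: n2 is the product of the
// normalized dimensions (the per-row reduction length) and n1 is the product
// of all leading dimensions. E.g. an input of shape [B, T, H] with
// normalized_shape [H] gives n1 = B*T and n2 = H.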
void compute_n1_n2(
at::Tensor input,
at::IntArrayRef normalized_shape,
int& n1,
int& n2)
{
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
assert( input.sizes()[i+idiff] == normalized_shape[i] );
n2 *= normalized_shape[i];
}
n1 = 1;
for (int i = 0; i < idiff; ++i) {
n1 *= input.sizes()[i];
}
}
void check_args(
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta
)
{
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(
at::Tensor input,
at::IntArrayRef normalized_shape,
int& n1,
int& n2
)
{
int64_t normalized_ndim = normalized_shape.size();
if (normalized_ndim < 1) {
std::stringstream ss;
ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
<< "containing at least one element, but got normalized_shape="
<< normalized_shape;
throw std::runtime_error(ss.str());
}
auto input_shape = input.sizes();
auto input_ndim = input.dim();
if (input_ndim < normalized_ndim ||
!input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
std::stringstream ss;
ss << "Given normalized_shape=" << normalized_shape
<< ", expected input with shape [*";
for (auto size : normalized_shape) {
ss << ", " << size;
}
ss << "], but got input of size" << input_shape;
throw std::runtime_error(ss.str());
}
compute_n1_n2(input,normalized_shape,n1,n2);
}
void check_args(
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
int& n1,
int& n2
)
{
check_args(input,normalized_shape,n1,n2);
check_args(normalized_shape,gamma,beta);
}
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon);
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> layer_norm(
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1,n2;
check_args(input,normalized_shape,gamma,beta,n1,n2);
TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 384 || n2 == 512 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
n2 == 1536 || n2 == 1792 || n2 == 2048, "dimension is not supported");
at::Tensor output = at::empty_like(input);
at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type()));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
normalized_shape,&gamma,&beta,epsilon);
return {output, mean, invvar};
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_input
);
at::Tensor layer_norm_gradient(
at::Tensor dout,
at::Tensor mean,
at::Tensor invvar,
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1,n2;
check_args(input,normalized_shape,gamma,beta,n1,n2);
TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 384 || n2 == 512 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
n2 == 1536 || n2 == 1792 || n2 == 2048, "dimension is not supported");
at::Tensor grad_input = at::empty_like(input);
cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
normalized_shape,&gamma,&beta,epsilon,
&grad_input);
return grad_input;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &layer_norm, "LayerNorm fast forward (CUDA)");
m.def("backward", &layer_norm_gradient, "LayerNorm fast backward (CUDA)");
}
#include <torch/extension.h>
#include <vector>
#include <cassert>
namespace {
void compute_n1_n2(
at::Tensor input,
at::IntArrayRef normalized_shape,
int& n1,
int& n2)
{
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
assert( input.sizes()[i+idiff] == normalized_shape[i] );
n2 *= normalized_shape[i];
}
n1 = 1;
for (int i = 0; i < idiff; ++i) {
n1 *= input.sizes()[i];
}
}
void check_args(
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta
)
{
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(
at::Tensor input,
at::IntArrayRef normalized_shape,
int& n1,
int& n2
)
{
int64_t normalized_ndim = normalized_shape.size();
if (normalized_ndim < 1) {
std::stringstream ss;
ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
<< "containing at least one element, but got normalized_shape="
<< normalized_shape;
throw std::runtime_error(ss.str());
}
auto input_shape = input.sizes();
auto input_ndim = input.dim();
if (input_ndim < normalized_ndim ||
!input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
std::stringstream ss;
ss << "Given normalized_shape=" << normalized_shape
<< ", expected input with shape [*";
for (auto size : normalized_shape) {
ss << ", " << size;
}
ss << "], but got input of size" << input_shape;
throw std::runtime_error(ss.str());
}
compute_n1_n2(input,normalized_shape,n1,n2);
}
void check_args(
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
int& n1,
int& n2
)
{
check_args(input,normalized_shape,n1,n2);
check_args(normalized_shape,gamma,beta);
}
}
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_gamma,
at::Tensor* grad_beta
);
std::vector<at::Tensor> layer_norm_gradient(
at::Tensor dout,
at::Tensor mean,
at::Tensor invvar,
at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1,n2;
check_args(input,normalized_shape,gamma,beta,n1,n2);
TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 384 || n2 == 512 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
n2 == 1536 || n2 == 1792 || n2 == 2048, "dimension is not supported");
at::Tensor grad_gamma = at::empty_like(gamma);
at::Tensor grad_beta = at::empty_like(beta);
cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
normalized_shape,&gamma,&beta,epsilon,
&grad_gamma,&grad_beta);
return {grad_gamma, grad_beta};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("backward", &layer_norm_gradient,
"LayerNorm fast backward for computing gamma and beta (CUDA)");
}
#include <iostream>
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include "util.h"
template <int Dim_, int VecSize_, int BatchesPerBlock_, int WarpsForOneBatchPerBlock_>
struct LNParameters {
static constexpr int Dim = Dim_;
static constexpr int VecSize = VecSize_;
static constexpr int WarpSize = 32;
static constexpr int BatchesPerBlock = BatchesPerBlock_;
static constexpr int WarpStride = WarpSize * VecSize;
static constexpr int WarpsForOneBatchPerBlock = WarpsForOneBatchPerBlock_;
static constexpr int Iterations = Dim / WarpStride / WarpsForOneBatchPerBlock;
static constexpr int BatchStride = Dim / WarpsForOneBatchPerBlock;
static constexpr int ThreadsPerBlock = BatchesPerBlock * WarpSize * WarpsForOneBatchPerBlock;
static_assert(Dim == WarpsForOneBatchPerBlock * WarpStride * Iterations, "");
static_assert(Dim == BatchStride * WarpsForOneBatchPerBlock, "");
};
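// Example instantiation: LNParameters<1024, 2, 4, 1> gives WarpStride = 64,
// Iterations = 16, BatchStride = 1024 and ThreadsPerBlock = 128, i.e. each
// 32-thread warp normalizes one row of 1024 elements with 2-wide vector
// loads, and 4 rows are processed per block.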
template <typename IndexType, typename input_t, typename output_t, typename acc_t, typename Parameters>
__global__ void layernorm_forward(output_t *dst, const input_t *src, const input_t *gamma, const input_t *beta,
acc_t *mean, acc_t *invvar, IndexType bsz, acc_t epsilon) {
static_assert(Parameters::WarpsForOneBatchPerBlock == 1, "");
IndexType batch = blockIdx.x * Parameters::BatchesPerBlock + threadIdx.y;
if (batch < bsz) {
src += batch * Parameters::Dim + threadIdx.x * Parameters::VecSize;
dst += batch * Parameters::Dim + threadIdx.x * Parameters::VecSize;
gamma += threadIdx.x * Parameters::VecSize;
beta += threadIdx.x * Parameters::VecSize;
using VecInType = VecType<input_t, Parameters::VecSize>;
VecInType elements[Parameters::Iterations];
VecInType gamma_reg[Parameters::Iterations];
VecInType beta_reg[Parameters::Iterations];
#pragma unroll
for (int i = 0; i < Parameters::Iterations; ++i) {
elements[i] = *(VecInType *)(src + i * Parameters::WarpStride);
gamma_reg[i] = *(VecInType *)(gamma + i * Parameters::WarpStride);
beta_reg[i] = *(VecInType *)(beta + i * Parameters::WarpStride);
}
input_t *elements_l = (input_t *)elements;
input_t *gamma_l = (input_t *)gamma_reg;
input_t *beta_l = (input_t *)beta_reg;
acc_t sum = 0.0;
#pragma unroll
for (int i = 0; i < Parameters::Iterations * Parameters::VecSize; ++i) {
sum += (acc_t)elements_l[i];
}
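// Butterfly (XOR) warp reduction: after log2(WarpSize) shuffle steps, every
// lane in the warp holds the full row sum; the same pattern is used for the
// variance below.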
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
sum += SHFL_XOR(sum, offset, Parameters::WarpSize);
}
acc_t mu = sum / Parameters::Dim;
acc_t var = 0.0;
#pragma unroll
for (int i = 0; i < Parameters::Iterations * Parameters::VecSize; ++i) {
acc_t diff = (acc_t)elements_l[i] - mu;
var += diff * diff;
}
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
var += SHFL_XOR(var, offset, Parameters::WarpSize);
}
const acc_t rsigma = rsqrtf(var / Parameters::Dim + epsilon);
#pragma unroll
for (int i = 0; i < Parameters::Iterations * Parameters::VecSize; ++i) {
elements_l[i] = (input_t)(((acc_t)elements_l[i] - mu) * rsigma) * gamma_l[i] + beta_l[i];
}
#pragma unroll
for (int i = 0; i < Parameters::Iterations; ++i) {
*(VecInType *)(dst + i * Parameters::WarpStride) = elements[i];
}
if (threadIdx.x == 0) {
mean[batch] = mu;
invvar[batch] = rsigma;
}
}
}
template <typename IndexType, typename input_t, typename output_t, typename acc_t, typename Parameters>
__global__ void layernorm_backward_x(output_t *dst, const input_t *input, const input_t *grad, const input_t *gamma,
const acc_t *mean, const acc_t *invvar, IndexType bsz) {
IndexType batch = blockIdx.x * Parameters::BatchesPerBlock + threadIdx.y;
if (batch < bsz) {
input += batch * Parameters::Dim + threadIdx.x * Parameters::VecSize;
dst += batch * Parameters::Dim + threadIdx.x * Parameters::VecSize;
grad += batch * Parameters::Dim + threadIdx.x * Parameters::VecSize;
gamma += threadIdx.x * Parameters::VecSize;
using VecInType = VecType<input_t, Parameters::VecSize>;
VecInType elements[Parameters::Iterations];
VecInType grad_reg[Parameters::Iterations];
VecInType gamma_reg[Parameters::Iterations];
#pragma unroll
for (int i = 0; i < Parameters::Iterations; ++i) {
elements[i] = *(VecInType *)(input + i * Parameters::WarpStride);
grad_reg[i] = *(VecInType *)(grad + i * Parameters::WarpStride);
gamma_reg[i] = *(VecInType *)(gamma + i * Parameters::WarpStride);
}
input_t *elements_l = (input_t *)elements;
input_t *grad_l = (input_t *)grad_reg;
input_t *gamma_l = (input_t *)gamma_reg;
const acc_t mu = mean[batch];
const acc_t var = invvar[batch];
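// Standard LayerNorm input gradient. With xhat = (x - mu) * invvar:
//   dx = invvar * ( dy*gamma - mean(dy*gamma) - xhat * mean(dy*gamma*xhat) )
// sum1/sum2 below accumulate the two row means (note `var` here is invvar).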
acc_t sum1 = 0.0, sum2 = 0.0;
#pragma unroll
for (int i = 0; i < Parameters::Iterations * Parameters::VecSize; ++i) {
elements_l[i] = elements_l[i] - (input_t)mu;
sum1 += (acc_t)(elements_l[i] * grad_l[i] * gamma_l[i]);
sum2 += (acc_t)(grad_l[i] * gamma_l[i]);
}
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
sum1 += SHFL_XOR(sum1, offset, Parameters::WarpSize);
}
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
sum2 += SHFL_XOR(sum2, offset, Parameters::WarpSize);
}
sum1 *= var * var * var / Parameters::Dim;
sum2 *= var / Parameters::Dim;
#pragma unroll
for (int i = 0; i < Parameters::Iterations * Parameters::VecSize; ++i) {
elements_l[i] = grad_l[i] * gamma_l[i] * (input_t)var - (input_t)sum1 * elements_l[i] - (input_t)sum2;
}
#pragma unroll
for (int i = 0; i < Parameters::Iterations; ++i) {
*(VecInType *)(dst + i * Parameters::WarpStride) = elements[i];
}
}
}
#define LAUNCH_FORWARD_KERNEL(len, vec, batches, type) \
{ \
dim3 threads(32, batches); \
int blocks = DIV_CELL(n1, batches); \
layernorm_forward<size_t, type, type, float, LNParameters<len, vec, batches, 1>> \
<<<blocks, threads, 0, stream>>> \
((type *)output->data_ptr(), (type *)input->data_ptr(), (type *)gamma->data_ptr(), \
(type *)beta->data_ptr(), (float *)mean->data_ptr(), (float *)invvar->data_ptr(), n1, epsilon); \
break; \
}
#define LAUNCH_BACKWARD_KERNEL(len, vec, batches, type) \
{ \
dim3 threads(32, batches); \
int blocks = DIV_CELL(n1, batches); \
layernorm_backward_x<size_t, type, type, float, LNParameters<len, vec, batches, 1>> \
<<<blocks, threads, 0, stream>>> \
((type *)grad_input->data_ptr(), (type *)input->data_ptr(), (type *)dout->data_ptr(), \
(type *)gamma->data_ptr(), (float *)mean->data_ptr(), (float *)invvar->data_ptr(), n1); \
break; \
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon)
{
using namespace at;
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
auto type = input->scalar_type();
if (type == at::ScalarType::BFloat16) {
switch (n2) {
case 64: LAUNCH_FORWARD_KERNEL(64, 2, 4, nv_bfloat16)
case 128: LAUNCH_FORWARD_KERNEL(128, 2, 4, nv_bfloat16)
case 256: LAUNCH_FORWARD_KERNEL(256, 2, 4, nv_bfloat16)
case 384: LAUNCH_FORWARD_KERNEL(384, 2, 4, nv_bfloat16)
case 512: LAUNCH_FORWARD_KERNEL(512, 2, 4, nv_bfloat16)
case 768: LAUNCH_FORWARD_KERNEL(768, 2, 4, nv_bfloat16)
case 1024: LAUNCH_FORWARD_KERNEL(1024, 2, 4, nv_bfloat16)
case 1280: LAUNCH_FORWARD_KERNEL(1280, 2, 4, nv_bfloat16)
case 1536: LAUNCH_FORWARD_KERNEL(1536, 2, 4, nv_bfloat16)
case 1792: LAUNCH_FORWARD_KERNEL(1792, 2, 4, nv_bfloat16)
case 2048: LAUNCH_FORWARD_KERNEL(2048, 2, 4, nv_bfloat16)
}
} else if (type == at::ScalarType::Half) {
switch (n2) {
case 64: LAUNCH_FORWARD_KERNEL(64, 2, 4, half)
case 128: LAUNCH_FORWARD_KERNEL(128, 2, 4, half)
case 256: LAUNCH_FORWARD_KERNEL(256, 2, 4, half)
case 384: LAUNCH_FORWARD_KERNEL(384, 2, 4, half)
case 512: LAUNCH_FORWARD_KERNEL(512, 2, 4, half)
case 768: LAUNCH_FORWARD_KERNEL(768, 2, 4, half)
case 1024: LAUNCH_FORWARD_KERNEL(1024, 2, 4, half)
case 1280: LAUNCH_FORWARD_KERNEL(1280, 2, 4, half)
case 1536: LAUNCH_FORWARD_KERNEL(1536, 2, 4, half)
case 1792: LAUNCH_FORWARD_KERNEL(1792, 2, 4, half)
case 2048: LAUNCH_FORWARD_KERNEL(2048, 2, 4, half)
}
} else if (type == at::ScalarType::Float) {
switch (n2) {
case 64: LAUNCH_FORWARD_KERNEL(64, 1, 4, float)
case 128: LAUNCH_FORWARD_KERNEL(128, 1, 4, float)
case 256: LAUNCH_FORWARD_KERNEL(256, 1, 4, float)
case 384: LAUNCH_FORWARD_KERNEL(384, 1, 4, float)
case 512: LAUNCH_FORWARD_KERNEL(512, 1, 4, float)
case 768: LAUNCH_FORWARD_KERNEL(768, 1, 4, float)
case 1024: LAUNCH_FORWARD_KERNEL(1024, 1, 4, float)
case 1280: LAUNCH_FORWARD_KERNEL(1280, 1, 4, float)
case 1536: LAUNCH_FORWARD_KERNEL(1536, 1, 4, float)
case 1792: LAUNCH_FORWARD_KERNEL(1792, 1, 4, float)
case 2048: LAUNCH_FORWARD_KERNEL(2048, 1, 4, float)
}
}
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_input)
{
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
auto type = input->scalar_type();
if (type == at::ScalarType::BFloat16) {
switch (n2) {
case 64: LAUNCH_BACKWARD_KERNEL(64, 2, 4, nv_bfloat16)
case 128: LAUNCH_BACKWARD_KERNEL(128, 2, 4, nv_bfloat16)
case 256: LAUNCH_BACKWARD_KERNEL(256, 2, 4, nv_bfloat16)
case 384: LAUNCH_BACKWARD_KERNEL(384, 2, 4, nv_bfloat16)
case 512: LAUNCH_BACKWARD_KERNEL(512, 2, 4, nv_bfloat16)
case 768: LAUNCH_BACKWARD_KERNEL(768, 2, 4, nv_bfloat16)
case 1024: LAUNCH_BACKWARD_KERNEL(1024, 2, 4, nv_bfloat16)
case 1280: LAUNCH_BACKWARD_KERNEL(1280, 2, 4, nv_bfloat16)
case 1536: LAUNCH_BACKWARD_KERNEL(1536, 2, 4, nv_bfloat16)
case 1792: LAUNCH_BACKWARD_KERNEL(1792, 2, 4, nv_bfloat16)
case 2048: LAUNCH_BACKWARD_KERNEL(2048, 2, 4, nv_bfloat16)
}
} else if (type == at::ScalarType::Half) {
switch (n2) {
case 64: LAUNCH_BACKWARD_KERNEL(64, 2, 4, half)
case 128: LAUNCH_BACKWARD_KERNEL(128, 2, 4, half)
case 256: LAUNCH_BACKWARD_KERNEL(256, 2, 4, half)
case 384: LAUNCH_BACKWARD_KERNEL(384, 2, 4, half)
case 512: LAUNCH_BACKWARD_KERNEL(512, 2, 4, half)
case 768: LAUNCH_BACKWARD_KERNEL(768, 2, 4, half)
case 1024: LAUNCH_BACKWARD_KERNEL(1024, 2, 4, half)
case 1280: LAUNCH_BACKWARD_KERNEL(1280, 2, 4, half)
case 1536: LAUNCH_BACKWARD_KERNEL(1536, 2, 4, half)
case 1792: LAUNCH_BACKWARD_KERNEL(1792, 2, 4, half)
case 2048: LAUNCH_BACKWARD_KERNEL(2048, 2, 4, half)
}
} else if (type == at::ScalarType::Float) {
switch (n2) {
case 64: LAUNCH_BACKWARD_KERNEL(64, 1, 4, float)
case 128: LAUNCH_BACKWARD_KERNEL(128, 1, 4, float)
case 256: LAUNCH_BACKWARD_KERNEL(256, 1, 4, float)
case 384: LAUNCH_BACKWARD_KERNEL(384, 1, 4, float)
case 512: LAUNCH_BACKWARD_KERNEL(512, 1, 4, float)
case 768: LAUNCH_BACKWARD_KERNEL(768, 1, 4, float)
case 1024: LAUNCH_BACKWARD_KERNEL(1024, 1, 4, float)
case 1280: LAUNCH_BACKWARD_KERNEL(1280, 1, 4, float)
case 1536: LAUNCH_BACKWARD_KERNEL(1536, 1, 4, float)
case 1792: LAUNCH_BACKWARD_KERNEL(1792, 1, 4, float)
case 2048: LAUNCH_BACKWARD_KERNEL(2048, 1, 4, float)
}
}
}
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include <THC/THCDeviceUtils.cuh>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include "type_shim.h"
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
// template <typename T>
// struct SharedMemory
// {
// // Ensure that we won't compile any un-specialized types
// __device__ T *getPointer()
// {
// extern __device__ void error(void);
// error();
// return NULL;
// }
// };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;
template <>
struct SharedMemory <float>
{
__device__ float *getPointer()
{
extern __shared__ float s_float[];
return s_float;
}
};
template <>
struct SharedMemory <double>
{
__device__ double *getPointer()
{
extern __shared__ double s_double[];
return s_double;
}
};
}
template<typename T, typename U> __device__
void cuLoadWriteStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const T* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
const int i2 = i2_off + k;
const int load_idx = i1*n2+i2;
const int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
const int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
}
template<typename T, typename U> __device__
void cuLoadAddStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const T* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
const int i2 = i2_off + k;
const int load_idx = i1*n2+i2;
const int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
}
template<typename T, typename U> __global__
void cuComputePartGradGammaBeta(
const T* __restrict__ dout,
const T* __restrict__ input,
const int n1,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar,
U epsilon,
U* part_grad_gamma,
U* part_grad_beta)
{
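// First pass of a two-pass reduction over rows: each block accumulates
// dbeta = sum(dout) and dgamma = sum(dout * (x - mean) * invvar) for its
// slice of rows into shared memory, writing partial results of shape
// [part_size, n2]; cuComputeGradGammaBeta then reduces over part_size.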
const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int row_stride = blockDim.x+1;
const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
SharedMemory<U> shared;
U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
U* warp_buf1 = (U*)buf;
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase the number of loads in flight
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < blockDim.y; ++k) {
const int row1 = threadIdx.y + k*blockDim.y;
const int idx1 = row1*row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
const int row1 = threadIdx.y;
const int row2 = threadIdx.y + offset;
const int idx1 = row1*row_stride + threadIdx.x;
const int idx2 = row2*row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
const int row1 = threadIdx.y;
const int row2 = threadIdx.y + 1;
const int idx1 = row1*row_stride + threadIdx.x;
const int idx2 = row2*row_stride + threadIdx.x;
part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
template<typename T, typename U> __global__
void cuComputeGradGammaBeta(
const U* part_grad_gamma,
const U* part_grad_beta,
const int part_size,
const int n1,
const int n2,
T* grad_gamma,
T* grad_beta)
{
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
U* buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / blockDim.y;
U sum_gamma = U(0);
U sum_beta = U(0);
const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
sum_beta += part_grad_beta_ptr[warp_offset*n2];
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx+nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx+nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
template<typename T, typename U>
void HostLayerNormGradient(
const T* dout,
const U* mean,
const U* invvar,
at::Tensor* input,
int n1,
int n2,
const T* gamma,
const T* beta,
double epsilon,
T* grad_gamma,
T* grad_beta
)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
const dim3 threads2(32,4,1);
const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype((input->scalar_type()==at::ScalarType::Half || input->scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input->scalar_type()));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout,
input->data_ptr<T>(),
n1,n2,
mean,
invvar,
U(epsilon),
part_grad_gamma.data_ptr<U>(),
part_grad_beta.data_ptr<U>());
const dim3 threads3(32,8,1);
const dim3 blocks3((n2+threads3.x-1)/threads3.x,1,1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
part_grad_gamma.data_ptr<U>(),
part_grad_beta.data_ptr<U>(),
part_size,
n1,n2,
grad_gamma,
grad_beta);
}
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_gamma,
at::Tensor* grad_beta)
{
using namespace at;
DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BF16(input->scalar_type(), 0, "cuComputeGradInput",
using accscalar_t = at::acc_type<scalar_t_0, true>;
HostLayerNormGradient(
dout->data_ptr<scalar_t_0>(),
mean->data_ptr<accscalar_t>(),
invvar->data_ptr<accscalar_t>(),
input,
n1,n2,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma->data_ptr<scalar_t_0>(),
beta->data_ptr<scalar_t_0>(),
epsilon,
grad_gamma->data_ptr<scalar_t_0>(),
grad_beta->data_ptr<scalar_t_0>());
)
}
#include <torch/extension.h>
at::Tensor multi_tensor_l2norm_cuda(
int chunk_size,
std::vector<std::vector<at::Tensor>> tensor_lists);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("l2norm", &multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors");
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <assert.h>
#include <iostream>
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
template<int n> struct TensorListMetadata
{
void* addresses[n][depth_to_max_tensors[n-1]];
int sizes[depth_to_max_tensors[n-1]];
unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
int block_to_chunk[depth_to_max_blocks[n-1]];
int start_tensor_this_launch;
};
template<typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(
int chunk_size,
T tl,
U callable,
ArgTypes... args)
{
callable(chunk_size, tl, args...);
}
template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
int block_size,
int chunk_size,
const std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
ArgTypes... args)
{
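// Tensors are processed in fixed-size chunks. Chunk metadata (base
// addresses, sizes, block->tensor and block->chunk maps) is packed into
// TensorListMetadata, which travels to the device by value as a kernel
// argument; the kernel is launched whenever the metadata fills up (tensor
// or block limit) or the last chunk is reached.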
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
auto ref_dtype = tensor_lists[0][0].scalar_type();
for (int l = 0; l < tensor_lists.size(); l++)
{
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for(int t = 0; t < tensor_lists[l].size(); t++)
{
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
#endif
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].scalar_type() == ref_dtype, "A tensor was not the same dtype as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
}
}
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth> tl;
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream();
tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
for(int t = 0; t < ntensors; t++)
{
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for(int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;
int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
for(int chunk = 0; chunk < chunks_this_tensor; chunk++)
{
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth-1] &&
chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
if(tensors_full || blocks_full || last_chunk)
{
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size,
tl,
callable,
args...);
AT_CUDA_CHECK(cudaGetLastError());
loc_block_info = 0;
if(chunk == chunks_this_tensor - 1)
{
loc_tensor_info = 0;
tl.start_tensor_this_launch = t + 1;
}
else
{
tl.sizes[0] = tl.sizes[loc_tensor_info-1];
for(int d = 0; d < depth; d++)
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
loc_tensor_info = 1;
tl.start_tensor_this_launch = t;
}
}
}
}
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_bf16.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
template<typename x_t>
struct L2NormFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<1>& tl,
float* output)
{
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
x_t* x = (x_t*)tl.addresses[0][tensor_loc];
x += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
__shared__ float s_vals[512];
float vals[ILP];
x_t r_x[ILP];
for(int i = 0; i < ILP; i++)
{
vals[i] = 0.0f;
r_x[i] = (x_t)0.0f;
}
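// Fast path: when the chunk is ILP-aligned, load ILP elements at a time
// with a single vector load (load_store); otherwise fall back to
// bounds-checked scalar loads.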
if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
{
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(r_x, x, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
float next = static_cast<float>(r_x[ii]);
vals[ii] += next*next;
}
}
}
else
{
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
{
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
float next = static_cast<float>(x[i]);
vals[ii] += next*next;
}
}
}
}
float val = 0.f;
for(int i = 0; i < ILP; i++)
val += vals[i];
float res = reduce_block_into_lanes(s_vals, val);
if(threadIdx.x == 0)
{
output[blockIdx.x] += res;
}
}
};
__global__ void cleanup(
float* output,
float* ret)
{
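// Second pass of the norm reduction: `output` holds one partial sum of
// squares per launched block (at most 320 == depth_to_max_blocks); a single
// block sums them and takes the square root to produce the final L2 norm.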
__shared__ float vals[512];
if(blockIdx.x == 0)
{
float val = 0;
if(threadIdx.x < 320)
val = output[threadIdx.x];
float final = reduce_block_into_lanes(vals, val);
if(threadIdx.x == 0)
*ret = sqrt(final);
}
}
at::Tensor multi_tensor_l2norm_cuda(
int chunk_size,
std::vector<std::vector<at::Tensor>> tensor_lists)
{
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
switch (tensor_lists[0][0].scalar_type()){
case at::ScalarType::Float: {
multi_tensor_apply<1>(
BLOCK_SIZE,
chunk_size,
tensor_lists,
L2NormFunctor<float>(),
output.data_ptr<float>()
);
break;
}
case at::ScalarType::Half: {
multi_tensor_apply<1>(
BLOCK_SIZE,
chunk_size,
tensor_lists,
L2NormFunctor<half>(),
output.data_ptr<float>()
);
break;
}
case at::ScalarType::BFloat16: {
multi_tensor_apply<1>(
BLOCK_SIZE,
chunk_size,
tensor_lists,
L2NormFunctor<nv_bfloat16>(),
output.data_ptr<float>()
);
break;
}
}
AT_CUDA_CHECK(cudaGetLastError());
auto ret = at::empty({1}, output.options());
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
auto stream = at::cuda::getCurrentCUDAStream();
cleanup<<<1, 512, 0, stream>>>(
output.data_ptr<float>(),
ret.data_ptr<float>());
return ret;
}
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <curand_kernel.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include <iostream>
union float_int_32
{
uint32_t i;
float f;
};
__global__ void fp32_to_bf16(
const float* input,
nv_bfloat16* output,
const int tsize,
uint64_t seed,
uint64_t offset) {
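// Stochastic rounding: add a uniform random value in [0, 2^16) to the low
// 16 bits of the fp32 word, then truncate toward zero (__float2bfloat16_rz).
// The discarded fraction determines the probability of rounding up, making
// the rounding unbiased in expectation.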
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < tsize) {
float_int_32 d;
d.f = input[i];
curandStatePhilox4_32_10_t state;
curand_init(seed, i, offset, &state);
d.i += curand(&state) & 0x0000ffff;
output[i] = __float2bfloat16_rz(d.f);
}
}
void fused_fp32_to_bf16_sr_cuda(
at::Tensor & input,
at::Tensor & output)
{
int tsize = input.numel();
const int threadsPerBlock = 512;
const int blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(input), "parameter tensor is too large to be indexed with int32");
AT_ASSERTM(input.scalar_type() == at::ScalarType::Float, "expected input to be float32 tensor");
AT_ASSERTM(output.scalar_type() == at::ScalarType::BFloat16, "expected output to be bfloat16 tensor");
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());
rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(1);
}
uint64_t seed = std::get<0>(rng_engine_inputs);
uint64_t offset = std::get<1>(rng_engine_inputs);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
fp32_to_bf16<<<blocks, threadsPerBlock, 0, stream>>>(
(const float*)input.data_ptr(),
(nv_bfloat16*)output.data_ptr(),
tsize,
seed,
offset);
AT_CUDA_CHECK(cudaGetLastError());
}
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
void fused_fp32_to_bf16_sr_cuda(at::Tensor & input, at::Tensor & output);
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
void fused_fp32_to_bf16_sr(at::Tensor & input, at::Tensor & output) {
CHECK_INPUT(input);
CHECK_INPUT(output);
int64_t num_elem = input.numel();
AT_ASSERTM(output.numel() == num_elem, "number of elements in input and output tensors should be equal");
fused_fp32_to_bf16_sr_cuda(input, output);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fp32_to_bf16_sr", &fused_fp32_to_bf16_sr, "fused fp32 to bf16 random rounding");
}
#include <torch/extension.h>
#include <ATen/Generator.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <vector>
std::vector<c10::optional<torch::Tensor>> fwd_cuda(
bool is_training,
const torch::Tensor &input,
float dropout_prob,
c10::optional<at::Generator> gen_
);
torch::Tensor bwd_cuda(
torch::Tensor &output_grads,
const torch::Tensor &softmax_results,
const c10::optional<torch::Tensor> &dropout_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<c10::optional<torch::Tensor>> fwd(
bool is_training,
const torch::Tensor &input,
float dropout_prob,
c10::optional<at::Generator> gen_
) {
CHECK_INPUT(input);
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16 ||
input.scalar_type() == at::ScalarType::Float, "Only HALF/BFloat16/Float is supported");
return fwd_cuda(is_training, input, dropout_prob, gen_);
}
torch::Tensor bwd(
torch::Tensor &output_grads,
const torch::Tensor &softmax_results,
const c10::optional<torch::Tensor> &dropout_mask,
float dropout_prob
) {
CHECK_INPUT(output_grads);
CHECK_INPUT(softmax_results);
if (dropout_mask) {
CHECK_INPUT(dropout_mask.value());
}
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(!dropout_mask || dropout_mask->dim() == 1, "expected 1D tensor");
AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half ||
output_grads.scalar_type() == at::ScalarType::BFloat16 ||
output_grads.scalar_type() == at::ScalarType::Float, "Only HALF/BFloat16/Float is supported");
AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half ||
softmax_results.scalar_type() == at::ScalarType::BFloat16 ||
softmax_results.scalar_type() == at::ScalarType::Float, "Only HALF/BFloat16/Float is supported");
AT_ASSERTM(output_grads.scalar_type() == softmax_results.scalar_type(), "output_grads and softmax_results must have the same dtype");
return bwd_cuda(output_grads, softmax_results, dropout_mask, dropout_prob);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &fwd, "softmax dropout -- Forward.");
m.def("backward", &bwd, "softmax dropout -- Backward.");
}
#include <vector>
#include <iostream>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "softmax_fast.h"
std::vector<c10::optional<torch::Tensor>> fwd_cuda(
bool is_training,
const torch::Tensor &input,
float dropout_prob,
c10::optional<at::Generator> gen_
) {
const int attn_batches = input.size(0);
const int q_seq_len = input.size(1);
const int k_seq_len = input.size(2);
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(softmax_mask_dtype(k_seq_len));
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *input_ptr = reinterpret_cast<void *>(input.data_ptr());
void *softmax_results_ptr = reinterpret_cast<void *>(softmax_results.data_ptr());
// Padded Softmax
bool softmax_success = false;
auto scalar_type = input.scalar_type();
if (is_training && dropout_prob > 0.0f) {
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty(
{softmax_mask_size(attn_batches * q_seq_len, k_seq_len)}, mask_options
);
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs(softmax_rng_delta_offset(k_seq_len));
}
uint64_t seed = std::get<0>(rng_engine_inputs);
uint64_t offset = std::get<1>(rng_engine_inputs);
if (scalar_type == at::ScalarType::BFloat16){
softmax_success = dispatch_softmax_forward<nv_bfloat16, nv_bfloat16, float, true>(
reinterpret_cast<nv_bfloat16 *>(dropout_results.data_ptr()),
reinterpret_cast<nv_bfloat16 *>(softmax_results_ptr),
reinterpret_cast<const nv_bfloat16 *>(input_ptr),
reinterpret_cast<void *>(dropout_mask.data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches*q_seq_len, seed, offset);
} else if (scalar_type == at::ScalarType::Half){
softmax_success = dispatch_softmax_forward<half, half, float, true>(
reinterpret_cast<half *>(dropout_results.data_ptr()),
reinterpret_cast<half *>(softmax_results_ptr),
reinterpret_cast<const half *>(input_ptr),
reinterpret_cast<void *>(dropout_mask.data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches*q_seq_len, seed, offset);
} else if (scalar_type == at::ScalarType::Float){
softmax_success = dispatch_softmax_forward<float, float, float, true>(
reinterpret_cast<float *>(dropout_results.data_ptr()),
reinterpret_cast<float *>(softmax_results_ptr),
reinterpret_cast<const float *>(input_ptr),
reinterpret_cast<void *>(dropout_mask.data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches*q_seq_len, seed, offset);
} else {
softmax_success = false;
}
if (softmax_success) {
return {dropout_results, dropout_mask, softmax_results};
} else {
return {c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>()};
}
} else {
if (scalar_type == at::ScalarType::BFloat16){
softmax_success = dispatch_softmax_forward<nv_bfloat16, nv_bfloat16, float, false>(
reinterpret_cast<nv_bfloat16 *>(softmax_results_ptr),
nullptr,
reinterpret_cast<const nv_bfloat16 *>(input_ptr),
nullptr,
1.0,
k_seq_len,
attn_batches*q_seq_len, 0, 0);
} else if (scalar_type == at::ScalarType::Half){
softmax_success = dispatch_softmax_forward<half, half, float, false>(
reinterpret_cast<half *>(softmax_results_ptr),
nullptr,
reinterpret_cast<const half *>(input_ptr),
nullptr,
1.0,
k_seq_len,
attn_batches*q_seq_len, 0, 0);
} else if (scalar_type == at::ScalarType::Float){
softmax_success = dispatch_softmax_forward<float, float, float, false>(
reinterpret_cast<float *>(softmax_results_ptr),
nullptr,
reinterpret_cast<const float *>(input_ptr),
nullptr,
1.0,
k_seq_len,
attn_batches*q_seq_len, 0, 0);
} else {
softmax_success = false;
}
if (softmax_success) {
return {softmax_results, c10::optional<torch::Tensor>(), softmax_results};
} else {
return {c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>()};
}
}
}
torch::Tensor bwd_cuda(
torch::Tensor &output_grads,
const torch::Tensor &softmax_results,
const c10::optional<torch::Tensor> &dropout_mask,
float dropout_prob
)
{
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = output_grads.size(2);
auto scalar_type = output_grads.scalar_type();
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
if (dropout_mask) {
if (scalar_type == at::ScalarType::BFloat16){
dispatch_softmax_backward<nv_bfloat16, nv_bfloat16, float, false, true>(
reinterpret_cast<nv_bfloat16 *>(output_grads.data_ptr()),
reinterpret_cast<const nv_bfloat16 *>(output_grads.data_ptr()),
reinterpret_cast<const nv_bfloat16 *>(softmax_results.data_ptr()),
reinterpret_cast<const void *>(dropout_mask->data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches*q_seq_len);
} else if (scalar_type == at::ScalarType::Half){
dispatch_softmax_backward<half, half, float, false, true>(
reinterpret_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<const half *>(output_grads.data_ptr()),
reinterpret_cast<const half *>(softmax_results.data_ptr()),
reinterpret_cast<const void *>(dropout_mask->data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches*q_seq_len);
} else if (scalar_type == at::ScalarType::Float){
dispatch_softmax_backward<float, float, float, false, true>(
reinterpret_cast<float *>(output_grads.data_ptr()),
reinterpret_cast<const float *>(output_grads.data_ptr()),
reinterpret_cast<const float *>(softmax_results.data_ptr()),
reinterpret_cast<const void *>(dropout_mask->data_ptr()),
1.0f - dropout_prob,
k_seq_len,
attn_batches*q_seq_len);
}
} else {
if (scalar_type == at::ScalarType::BFloat16){
dispatch_softmax_backward<nv_bfloat16, nv_bfloat16, float, false, false>(
reinterpret_cast<nv_bfloat16 *>(output_grads.data_ptr()),
reinterpret_cast<nv_bfloat16 *>(output_grads.data_ptr()),
reinterpret_cast<const nv_bfloat16 *>(softmax_results.data_ptr()),
nullptr,
1.0f,
k_seq_len,
attn_batches*q_seq_len);
} else if (scalar_type == at::ScalarType::Half){
dispatch_softmax_backward<half, half, float, false, false>(
reinterpret_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<half *>(output_grads.data_ptr()),
reinterpret_cast<const half *>(softmax_results.data_ptr()),
nullptr,
1.0f,
k_seq_len,
attn_batches*q_seq_len);
} else if (scalar_type == at::ScalarType::Float){
dispatch_softmax_backward<float, float, float, false, false>(
reinterpret_cast<float *>(output_grads.data_ptr()),
reinterpret_cast<float *>(output_grads.data_ptr()),
reinterpret_cast<const float *>(softmax_results.data_ptr()),
nullptr,
1.0f,
k_seq_len,
attn_batches*q_seq_len);
}
}
//backward pass is completely in-place
return output_grads;
}
#pragma once
#include <iostream>
#include <type_traits>
#include <limits>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <curand_kernel.h>
#include "util.h"
template <int N>
using IntegerBits = typename std::conditional<N <= 8, uint8_t,
typename std::conditional<N <= 16, uint16_t,
typename std::conditional<N <= 32, uint32_t,
typename std::conditional<N <= 64, uint64_t, void>::type
>::type
>::type
>::type;
template <int LogElements>
struct SoftmaxParameters {
static_assert(LogElements <= 11, "");
static constexpr int Elements = 1 << LogElements;
static constexpr int WarpBatch = Elements <= 128 ? 2 : 1;
static constexpr int WarpIterations = Elements <= 32 ? 1 : Elements / 32;
using MaskType = IntegerBits<WarpIterations>;
static constexpr int WarpSize = Elements <= 32 ? Elements : 32;
static constexpr int MaskStride = WarpSize;
};
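// Example: LogElements = 10 (rows of 1024 elements) gives WarpBatch = 1,
// WarpIterations = 32, WarpSize = 32, and a 32-bit mask word per thread.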
inline int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
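// The dropout mask packs one bit per WarpIteration per thread into the
// smallest integer type that fits: rows of up to 256 elements need at most
// 8 bits per thread, up to 512 need 16, up to 1024 need 32, and larger
// rows need 64.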
inline at::ScalarType softmax_mask_dtype(int elements) {
if (elements > 1024) {
return torch::kInt64;
} else if (elements > 512) {
return torch::kInt32;
} else if (elements > 256) {
return torch::kInt16;
}
return torch::kInt8;
}
inline int softmax_mask_size(int batch_size, int elements) {
int log2_elements = log2_ceil(elements);
int e = 1 << log2_elements;
int warp_size = e < 32 ? e : 32;
return batch_size * warp_size;
}
inline int softmax_rng_delta_offset(int elements) {
int log2_elements = log2_ceil(elements);
int e = 1 << log2_elements;
int warp_iterations = e <= 32 ? 1 : e / 32;
int warp_batch = e <= 128 ? 2 : 1;
return warp_iterations * warp_batch;
}
template <
typename input_t, typename output_t, typename acc_t,
typename Parameters, bool NeedMask
>
__global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const output_t *src,
typename Parameters::MaskType *mask, acc_t p, int batch_size, int element_count, uint64_t seed, uint64_t rand_offset) {
using MaskType = typename Parameters::MaskType;
curandStatePhilox4_32_10_t state;
int64_t first_batch = (static_cast<int64_t>(blockDim.y) * static_cast<int64_t>(blockIdx.x) + threadIdx.y) * Parameters::WarpBatch;
// There might be multiple batches per warp; compute the index within the batch.
int64_t local_idx = threadIdx.x;
const int64_t thread_offset = first_batch * element_count + local_idx;
if IF_CONSTEXPR (NeedMask) {
curand_init(seed, thread_offset, rand_offset, &state);
}
// batch_size might not be a multiple of Parameters::WarpBatch. Check how
// many batches have to be computed within this warp.
int local_batches = batch_size - first_batch;
if (local_batches > Parameters::WarpBatch)
local_batches = Parameters::WarpBatch;
src += thread_offset;
dst += thread_offset;
if IF_CONSTEXPR (NeedMask) {
dst_orig += thread_offset;
mask += first_batch * Parameters::MaskStride;
}
// load data from global memory
input_t elements_input[Parameters::WarpBatch][Parameters::WarpIterations];
for (int i = 0; i < Parameters::WarpBatch; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
for (int it = 0; it < Parameters::WarpIterations; ++it) {
int element_index = local_idx + it * Parameters::WarpSize;
elements_input[i][it] = -std::numeric_limits<float>::infinity();
if (element_index < batch_element_count) {
elements_input[i][it] = src[i * element_count + it * Parameters::WarpSize];
}
}
}
// convert input_t to acc_t
acc_t elements[Parameters::WarpBatch][Parameters::WarpIterations];
for (int i = 0; i < Parameters::WarpBatch; ++i) {
for (int it = 0; it < Parameters::WarpIterations; ++it) {
elements[i][it] = elements_input[i][it];
}
}
// compute local max_value
// take the max_value of the first element to avoid one max call
acc_t max_value[Parameters::WarpBatch];
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
max_value[i] = elements[i][0];
}
#pragma unroll
for (int it = 1; it < Parameters::WarpIterations; ++it) {
for (int i = 0; i < Parameters::WarpBatch; ++i) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
// reduction max_value
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
acc_t val[Parameters::WarpBatch];
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
val[i] = SHFL_XOR(max_value[i], offset, Parameters::WarpSize);
}
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
}
}
// compute local sum
acc_t sum[Parameters::WarpBatch] { 0.0f };
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
for (int it = 0; it < Parameters::WarpIterations; ++it) {
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
}
// reduction sum
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
sum[i] += SHFL_XOR(sum[i], offset, Parameters::WarpSize);
}
}
// store result
if IF_CONSTEXPR (NeedMask) {
const acc_t pinv = 1.0 / p;
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
if (i >= local_batches)
break;
MaskType m = 0;
if IF_CONSTEXPR (Parameters::WarpIterations == 1) {
float rand = curand_uniform(&state);
m = rand < p;
} else if IF_CONSTEXPR (Parameters::WarpIterations == 2) {
m = curand_uniform(&state) < p;
m |= (curand_uniform(&state) < p) << 1;
} else {
#pragma unroll
for (int j = 0; j < DIV_CELL(Parameters::WarpIterations, 4); ++j) {
float4 rand4 = curand_uniform4(&state);
m |= (((MaskType)(rand4.x < p)) << (j * 4))
| (((MaskType)(rand4.y < p)) << (j * 4 + 1))
| (((MaskType)(rand4.z < p)) << (j * 4 + 2))
| (((MaskType)(rand4.w < p)) << (j * 4 + 3));
}
}
mask[i * Parameters::MaskStride + local_idx] = m;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it) {
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < element_count) {
const output_t d = elements[i][it] / sum[i];
dst[i * element_count + it * Parameters::WarpSize] = (acc_t)d * ((acc_t)((m >> it) & 1) * pinv);
dst_orig[i * element_count + it * Parameters::WarpSize] = d;
}
else {
break;
}
}
}
} else {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it) {
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < element_count) {
dst[i * element_count + it * Parameters::WarpSize] = elements[i][it] / sum[i];
}
else {
break;
}
}
}
}
}
#define LAUNCH_FORWARD_KERNEL(l) \
softmax_warp_forward<input_t, output_t, acc_t, SoftmaxParameters<l>, NeedMask> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
dst, dst_orig, src, (typename SoftmaxParameters<l>::MaskType *)mask, p, \
batch_count, softmax_elements, seed, offset \
); \
return true;
template<typename input_t, typename output_t, typename acc_t, bool NeedMask>
bool dispatch_softmax_forward(output_t *dst, output_t *dst_orig, const input_t *src, void *mask, acc_t p,
int softmax_elements, int batch_count, uint64_t seed, uint64_t offset)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return false;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the Parameters::WarpSize constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// This value must match the Parameters::WarpBatch constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: LAUNCH_FORWARD_KERNEL(0)
case 1: LAUNCH_FORWARD_KERNEL(1)
case 2: LAUNCH_FORWARD_KERNEL(2)
case 3: LAUNCH_FORWARD_KERNEL(3)
case 4: LAUNCH_FORWARD_KERNEL(4)
case 5: LAUNCH_FORWARD_KERNEL(5)
case 6: LAUNCH_FORWARD_KERNEL(6)
case 7: LAUNCH_FORWARD_KERNEL(7)
case 8: LAUNCH_FORWARD_KERNEL(8)
case 9: LAUNCH_FORWARD_KERNEL(9)
case 10: LAUNCH_FORWARD_KERNEL(10)
case 11: LAUNCH_FORWARD_KERNEL(11)
default: return false;
}
}
return false;
}
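// Usage sketch (illustrative; the pointer and tensor names are assumptions,
// mirroring the fused-attention callers): p is the keep-probability
// (1 - dropout_prob), `rows` is the number of softmax rows, and seed/offset
// come from the host's Philox RNG state. dst receives the dropped-out
// probabilities, dst_orig the raw softmax.
//
//   const int rows = attn_batches * q_seq_len;
//   auto mask = torch::empty({softmax_mask_size(rows, k_seq_len)},
//       input.options().dtype(softmax_mask_dtype(k_seq_len)));
//   dispatch_softmax_forward<half, half, float, true>(
//       dropout_out_ptr, softmax_out_ptr, input_ptr, mask.data_ptr(),
//       1.0f - dropout_prob, k_seq_len, rows, seed, offset);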
template <
typename input_t, typename output_t, typename acc_t, typename Parameters,
bool IsLogSoftmax, bool NeedMask
>
__global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output,
const typename Parameters::MaskType *mask, acc_t p, int batch_size, int element_count)
{
using MaskType = typename Parameters::MaskType;
int64_t first_batch = (static_cast<int64_t>(blockDim.y) * static_cast<int64_t>(blockIdx.x) + threadIdx.y) * Parameters::WarpBatch;
// batch_size might not be a multiple of Parameters::WarpBatch. Check how
// many batches have to be computed within this warp.
int local_batches = batch_size - first_batch;
if (local_batches > Parameters::WarpBatch)
local_batches = Parameters::WarpBatch;
// there might be multiple batches per warp. compute the index within the batch
int64_t local_idx = threadIdx.x;
// the first element to process by the current thread
int64_t thread_offset = first_batch * element_count + local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
if IF_CONSTEXPR (NeedMask) {
mask += first_batch * Parameters::MaskStride;
}
// The nested loops over Parameters::WarpBatch and Parameters::WarpIterations
// could be collapsed into one loop, but that would obscure the logic of the
// algorithm, so the nested form is kept. This has no performance impact
// because the loops are fully unrolled anyway.
// load data from global memory
acc_t grad_reg[Parameters::WarpBatch][Parameters::WarpIterations];
acc_t output_reg[Parameters::WarpBatch][Parameters::WarpIterations];
if IF_CONSTEXPR (NeedMask) {
MaskType mask_reg[Parameters::WarpBatch];
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
if (i >= local_batches)
break;
mask_reg[i] = mask[i * Parameters::MaskStride + local_idx];
}
const acc_t pinv = 1.0 / p;
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
MaskType m = mask_reg[i];
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it) {
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < batch_element_count) {
grad_reg[i][it] =
(input_t)(
(acc_t)((m >> it) & 1) *
(acc_t)grad[i * element_count + it * Parameters::WarpSize] *
pinv
) *
output[i * element_count + it * Parameters::WarpSize];
output_reg[i][it] = output[i * element_count + it * Parameters::WarpSize];
} else {
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = acc_t(0);
}
}
}
} else {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it) {
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < batch_element_count) {
grad_reg[i][it] = grad[i * element_count + it * Parameters::WarpSize] *
output[i * element_count + it * Parameters::WarpSize];
output_reg[i][it] = output[i * element_count + it * Parameters::WarpSize];
} else {
grad_reg[i][it] = acc_t(0);
output_reg[i][it] = acc_t(0);
}
}
}
}
acc_t sum[Parameters::WarpBatch];
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < Parameters::WarpIterations; ++it) {
sum[i] += grad_reg[i][it];
}
}
#pragma unroll
for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
sum[i] += SHFL_XOR(sum[i], offset, Parameters::WarpSize);
}
}
// store result
#pragma unroll
for (int i = 0; i < Parameters::WarpBatch; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < Parameters::WarpIterations; ++it) {
int element_index = local_idx + it * Parameters::WarpSize;
if (element_index < element_count) {
// compute gradients
if IF_CONSTEXPR (IsLogSoftmax) {
gradInput[i * element_count + it * Parameters::WarpSize] =
(grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
} else {
gradInput[i * element_count + it * Parameters::WarpSize] =
(grad_reg[i][it] - output_reg[i][it] * sum[i]);
}
}
}
}
}
#define LAUNCH_BACKWARD_KERNEL(l) \
softmax_warp_backward<input_t, output_t, acc_t, SoftmaxParameters<l>, IsLogSoftmax, NeedMask> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
grad_input, grad, output, (const typename SoftmaxParameters<l>::MaskType *)mask, p, \
batch_count, softmax_elements \
); \
break;
template<typename input_t, typename output_t, typename acc_t, bool IsLogSoftmax, bool NeedMask>
void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output,
const void *mask, acc_t p, int softmax_elements, int batch_count)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
// This value must match the Parameters::WarpSize constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
// This value must match the Parameters::WarpBatch constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: LAUNCH_BACKWARD_KERNEL(0)
case 1: LAUNCH_BACKWARD_KERNEL(1)
case 2: LAUNCH_BACKWARD_KERNEL(2)
case 3: LAUNCH_BACKWARD_KERNEL(3)
case 4: LAUNCH_BACKWARD_KERNEL(4)
case 5: LAUNCH_BACKWARD_KERNEL(5)
case 6: LAUNCH_BACKWARD_KERNEL(6)
case 7: LAUNCH_BACKWARD_KERNEL(7)
case 8: LAUNCH_BACKWARD_KERNEL(8)
case 9: LAUNCH_BACKWARD_KERNEL(9)
case 10: LAUNCH_BACKWARD_KERNEL(10)
case 11: LAUNCH_BACKWARD_KERNEL(11)
default: break;
}
}
}
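// Usage sketch (illustrative): the fused-softmax callers run the backward
// fully in place, passing the same buffer as both grad_input and grad.
// That is safe because each thread loads its grad/output elements into
// registers before the store loop runs.
//
//   dispatch_softmax_backward<half, half, float, false, false>(
//       dY_ptr, dY_ptr, softmax_results_ptr, nullptr, 1.0f,
//       k_seq_len, attn_batches * q_seq_len);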
#include <ATen/ATen.h>
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
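// Usage sketch for this family of DISPATCH_* macros (illustrative;
// my_kernel is a placeholder, this macro and the variants below work the
// same way): LEVEL disambiguates nested dispatches, and the macro binds
// scalar_t_<LEVEL> to the concrete C++ type for each runtime dtype before
// running the body.
//
//   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "my_op",
//       my_kernel<scalar_t_0><<<blocks, threads>>>(
//           input.data_ptr<scalar_t_0>(), n));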
#define DISPATCH_FLOAT_AND_BF16(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: \
{ \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = x[tid] + x[tid+i];
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = x[tid] + x[tid+32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
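// Usage sketch (illustrative): a block-wide sum. `buf` must be shared
// memory with blockDim.x * blockDim.y elements. With share_result = true,
// buf[0..lanes-1] holds the reduced value for every thread after the call;
// the return value is only fully reduced for tid < lanes.
//
//   extern __shared__ float buf[];
//   float partial = reduce_block_into_lanes(buf, my_val, 1, true);
//   float block_sum = buf[0];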
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i]));
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid+32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
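// Note: despite the generic name, this variant reduces the maximum of
// absolute values, i.e. an L-infinity norm of the contributed values.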
#pragma once
#define DIV_CELL(a, b) (((a) + (b) - 1) / (b)) // integer ceiling division: a / b rounded up
#if __cplusplus >= 201703L
#define IF_CONSTEXPR constexpr
#else
#define IF_CONSTEXPR
#endif
template <typename T>
__device__ __forceinline__ T SHFL_XOR(T value, int laneMask, int width, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
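// e.g. a full-warp butterfly sum, as used by the softmax kernels above:
//   for (int offset = 16; offset > 0; offset /= 2)
//       v += SHFL_XOR(v, offset, 32);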
template <typename T, int N>
struct VecTypeImpl;
#define DEFINE_VEC_TYPE(t, n, tn) \
template <> \
struct VecTypeImpl<t, n> { \
using type = tn; \
};
DEFINE_VEC_TYPE(half, 1, half)
DEFINE_VEC_TYPE(__nv_bfloat16, 1, __nv_bfloat16)
DEFINE_VEC_TYPE(float, 1, float)
DEFINE_VEC_TYPE(half, 2, half2)
DEFINE_VEC_TYPE(__nv_bfloat16, 2, __nv_bfloat162)
DEFINE_VEC_TYPE(float, 2, float2)
DEFINE_VEC_TYPE(half, 4, uint64_t)
DEFINE_VEC_TYPE(__nv_bfloat16, 4, uint64_t)
DEFINE_VEC_TYPE(float, 4, float4)
template <typename T, int N>
using VecType = typename VecTypeImpl<T, N>::type;
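// e.g. VecType<half, 2> is half2 and VecType<float, 4> is float4; 4-wide
// half/bfloat16 accesses are expressed as a single uint64_t so they move
// with one 64-bit load/store.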
# ==================================================================
# module list
# ------------------------------------------------------------------
# python 3.8 (conda)
# pytorch 1.11.0 (conda)
# apex from github
# ==================================================================
FROM nvidia/cuda:11.3.0-cudnn8-devel-ubuntu20.04
ENV LANG C.UTF-8
ENV OFED_VERSION=5.3-1.0.0.1
RUN APT_INSTALL="apt-get install -y --no-install-recommends" && \
rm -rf /var/lib/apt/lists/* \
/etc/apt/sources.list.d/cuda.list \
/etc/apt/sources.list.d/nvidia-ml.list && \
apt-get update && \
DEBIAN_FRONTEND=noninteractive $APT_INSTALL \
software-properties-common \
&& \
apt-get update && \
DEBIAN_FRONTEND=noninteractive $APT_INSTALL \
build-essential \
apt-utils \
ca-certificates \
wget \
git \
vim \
libssl-dev \
curl \
unzip \
unrar \
cmake \
net-tools \
sudo \
autotools-dev \
rsync \
jq \
openssh-server \
tmux \
screen \
htop \
pdsh \
openssh-client \
lshw \
dmidecode \
util-linux \
automake \
autoconf \
libtool \
net-tools \
pciutils \
libpci-dev \
libaio-dev \
libcap2 \
libtinfo5 \
fakeroot \
devscripts \
debhelper \
nfs-common
RUN cd /tmp && \
wget -q http://content.mellanox.com/ofed/MLNX_OFED-${OFED_VERSION}/MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tgz && \
tar xzf MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64.tgz && \
MLNX_OFED_LINUX-${OFED_VERSION}-ubuntu20.04-x86_64/mlnxofedinstall --user-space-only --without-fw-update --force --all && \
rm -rf /tmp/MLNX_OFED_LINUX-${OFED_VERSION}*
RUN cd /tmp && \
mkdir -p /usr/local/nccl-rdma-sharp-plugins && \
DEBIAN_FRONTEND=noninteractive apt install -y zlib1g-dev && \
git clone --depth=1 https://github.com/Mellanox/nccl-rdma-sharp-plugins.git && \
cd nccl-rdma-sharp-plugins && \
./autogen.sh && \
./configure --prefix=/usr/local/nccl-rdma-sharp-plugins --with-cuda=/usr/local/cuda && \
make && \
make install
# ==================================================================
# python
# ------------------------------------------------------------------
# Set timezone
RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
ENV PATH /usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
ENV LD_LIBRARY_PATH /usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
ENV PYTHON_VERSION=3.8
RUN wget -O ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
chmod +x ~/miniconda.sh && \
~/miniconda.sh -b -p /opt/conda && \
rm ~/miniconda.sh
ENV PATH /opt/conda/bin:$PATH
RUN conda install -y python=3.8 && conda clean -ya
RUN conda install -y scipy scikit-learn pyyaml tensorboard tensorboardX && \
conda clean -ya
# install pytorch
RUN ldconfig
# ==================================================================
# pytorch
# ------------------------------------------------------------------
ENV TORCH_CUDA_ARCH_LIST "7.0;7.5;8.0"
RUN conda install -y numpy pyyaml scipy ipython mkl mkl-include ninja cython typing && \
conda clean -ya
RUN conda install pytorch=1.11.0 cudatoolkit=11.3 -c pytorch && \
conda clean -ya
RUN cd /tmp && \
git clone https://github.com/dptech-corp/Uni-Core && \
cd Uni-Core && \
python setup.py install && \
rm -rf /tmp/*
RUN pip install --no-cache-dir tokenizers lmdb biopython ml-collections timeout-decorator urllib3 tree dm-tree
ENV LD_LIBRARY_PATH=/usr/local/nccl-rdma-sharp-plugins/lib:$LD_LIBRARY_PATH
ENV PATH=/usr/mpi/gcc/openmpi-4.1.0rc5/bin:$PATH
ENV LD_LIBRARY_PATH=/usr/mpi/gcc/openmpi-4.1.0rc5/lib:$LD_LIBRARY_PATH
RUN ldconfig && \
apt-get clean && \
apt-get autoremove && \
rm -rf /var/lib/apt/lists/* /tmp/* && \
conda clean -ya
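# Usage sketch (the image tag is an assumption): build from the repo root
# and launch an interactive shell with all GPUs visible.
#   docker build -t uni-core:latest -f docker/Dockerfile .
#   docker run --gpus all --shm-size=16g -it uni-core:latest bash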
# importing these modules registers the example BERT task and model with the
# framework's registries as an import side effect
import bert.task
import bert.model