Commit bfba68d8 authored by Jun Ru Anderson, committed by Mandeep Singh Baines

[feat] add FusedAdam (#10)



Add FusedAdam, update the benchmark, and add tests.
Co-authored-by: Jun Ru Anderson <andersonic@fb.com>
parent 4f7d7d34
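A minimal usage sketch of the new optimizer, assuming the CUDA extension built successfully; the model and batch below are placeholders, not part of the commit:

```python
import torch
from fairscale.optim.adam import Adam  # raises ImportError if fused_adam_cuda was not built

model = torch.nn.Linear(128, 10).cuda()       # placeholder model
optimizer = Adam(model.parameters(), lr=1e-3)

batch = torch.randn(32, 128, device="cuda")   # placeholder batch
loss = model(batch).pow(2).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```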
......@@ -47,12 +47,19 @@ install_dep: &install_dep
python -c 'import torch; print("Torch version:", torch.__version__)'
python -m torch.utils.collect_env
install_repo: &install_repo
install_repo_cpu: &install_repo_cpu
- run:
name: Install Repository
command: |
python setup.py build develop
install_repo_gpu: &install_repo_gpu
- run:
name: Install Repository
command: |
export CUDA_HOME=/usr/local/cuda-10.1
python setup.py build develop
run_unittests: &run_unittests
- run:
name: Run Unit Tests
......@@ -97,7 +104,7 @@ jobs:
- ~/venv
key: cache-key-cpu-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo
- <<: *install_repo_cpu
- run:
name: Run Linter (isort)
......@@ -147,7 +154,7 @@ jobs:
- ~/venv
key: cache-key-gpu-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo
- <<: *install_repo_gpu
- <<: *run_unittests
......@@ -180,7 +187,7 @@ jobs:
- ~/venv
key: cache-key-gpu-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo
- <<: *install_repo_gpu
- <<: *run_transformer_benchmark
......
......@@ -51,6 +51,7 @@ isort
flake8
```
* Read the [editorconfig](.editorconfig) file to understand the exact coding style preferences.
* Place model-related Python code in fairscale/nn, optimizer-related Python code in fairscale/optim, and C++/CUDA extensions in fairscale/clib.
## Testing
......
......@@ -5,12 +5,20 @@ import time
import torch
import torch.nn as nn
from torch.optim.adam import Adam
import torchtext
from torchtext.data.utils import get_tokenizer
import fairscale.nn.pipe.pipe as pipe
try:
from fairscale.optim.adam import Adam # type: ignore
can_benchmark = True
except ImportError:
from torch.optim import Adam # type: ignore
can_benchmark = False
class EmbeddingLayer(nn.Embedding):
def __init__(self, ntoken, ninp, initrange):
......@@ -210,10 +218,10 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
)
print("=" * 89)
if len(model.balance) == 4:
if can_benchmark and len(model.balance) == 4:
# Assert that words per second is within 3 standard deviations of the average
# of five golden runs
assert wps > 19276.1 - (3 * 88)
# of six golden runs
assert wps > 20052.1 - (3 * 359)
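# i.e. wps must exceed 20052.1 - 3 * 359 = 18975.1 words per second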
print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
print("Peak allocated bytes on cuda:1: {:1d}".format(torch.cuda.memory_stats(1)["allocated_bytes.all.peak"]))
......@@ -222,7 +230,7 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
# Assert that memory usage on each GPU is within 10% of golden run
# Right-hand-side is golden run bytes * 110%
assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 365915648 * 1.1
assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 365916160 * 1.1
assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 1281024 * 1.1
assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 2788864 * 1.1
assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 190724608 * 1.1
......
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#define DATA_PTR data_ptr
#include <torch/extension.h>
// CUDA forward declaration
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// C++ interface
void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
CHECK_INPUT(p);
if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
CHECK_INPUT(m);
CHECK_INPUT(v);
CHECK_INPUT(g);
int64_t num_elem = p.numel();
AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");
fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("adam", &adam, "Adam optimized CUDA implementation.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
}
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/TensorUtils.h"
// #include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
#include "type_shim.h"
typedef enum{
ADAM_MODE_0 =0, // eps under square root
ADAM_MODE_1 =1 // eps outside square root
} adamMode_t;
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
GRAD_T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay)
{
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
for (int j = i; j < tsize; j+=totThreads) {
T scaled_grad = g[j]/grad_scale;
m[j] = b1*m[j] + (1-b1)*scaled_grad;
v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
float update = (m[j]/denom) + (decay*p[j]);
p[j] = (GRAD_T)((float)p[j] - (step_size*update));
if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
}
}
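For readability (not part of the commit), a hedged plain-PyTorch sketch of the per-element update that adam_cuda_kernel parallelizes, with the bias-corrected step size computed on the host as fused_adam_cuda does; names mirror the kernel arguments:

```python
import math
import torch

def adam_reference(p, m, v, g, lr, b1, b2, eps, grad_scale, step, eps_under_sqrt, decay):
    """One fused-Adam step in plain PyTorch (eps_under_sqrt=True corresponds to ADAM_MODE_0)."""
    # Host-side bias correction, as done when bias_correction == 1.
    step_size = lr * math.sqrt(1 - b2 ** step) / (1 - b1 ** step)
    scaled_grad = g / grad_scale
    m.mul_(b1).add_(scaled_grad, alpha=1 - b1)                    # m = b1*m + (1-b1)*g_scaled
    v.mul_(b2).addcmul_(scaled_grad, scaled_grad, value=1 - b2)   # v = b2*v + (1-b2)*g_scaled^2
    denom = (v + eps).sqrt() if eps_under_sqrt else v.sqrt() + eps
    update = m / denom + decay * p
    p.sub_(step_size * update)
    return p
```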
template <int DEPTH, typename T, typename GRAD_T>
struct AdamFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<DEPTH>& tl,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
adamMode_t mode,
const float decay)
{
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
GRAD_T* p = (GRAD_T *)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T *)tl.addresses[1][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T *)tl.addresses[2][tensor_loc];
v += chunk_idx*chunk_size;
GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
g += chunk_idx*chunk_size;
GRAD_T* p_copy = NULL;
if (DEPTH == 5) {
p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
p_copy += chunk_idx*chunk_size;
}
n -= chunk_idx*chunk_size;
T incoming_p[ILP];
T incoming_m[ILP];
T incoming_v[ILP];
T incoming_g[ILP];
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
incoming_p[ii] = 0;
incoming_m[ii] = 0;
incoming_v[ii] = 0;
incoming_g[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if (i < n && i < chunk_size) {
incoming_p[ii] = static_cast<T>(p[i]);
incoming_m[ii] = m[i];
incoming_v[ii] = v[i];
incoming_g[ii] = static_cast<T>(g[i]);
}
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = i_start + threadIdx.x + ii*blockDim.x;
if(j < n && j < chunk_size) {
T scaled_grad = incoming_g[ii]/grad_scale;
m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
float update = (m[j]/denom) + (decay*incoming_p[ii]);
p[j] = (GRAD_T)(incoming_p[ii] - (step_size*update));
if (DEPTH == 5) p_copy[j] = p[j];
}
}
}
}
};
void fused_adam_cuda(
at::Tensor & p,
at::Tensor & p_copy,
at::Tensor & m,
at::Tensor & v,
at::Tensor & g,
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay)
{
// using namespace at;
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Half, "expected parameter to be of half type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<scalar_t_0>(),
p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
m.DATA_PTR<accscalar_t>(),
v.DATA_PTR<accscalar_t>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<scalar_t_0>(),
NULL, //don't output p_copy for fp32, it's wasted write
m.DATA_PTR<scalar_t_0>(),
v.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode,
decay);
);
}
THCudaCheck(cudaGetLastError());
}
void fused_adam_cuda_mt(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay) {
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Half, "expected parameter to be of half type");
//dispatch is done on the gradient type
if (tl_sz == 5) {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
}
} else {
if (tl_sz == 5) {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
}
}
THCudaCheck(cudaGetLastError());
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "compat.h"
#include <assert.h>
// #include <iostream>
// This header is the one-stop shop for all your multi-tensor apply needs.
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
template<int n> struct TensorListMetadata
{
void* addresses[n][depth_to_max_tensors[n-1]];
int sizes[depth_to_max_tensors[n-1]];
unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int.
int start_tensor_this_launch;
};
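The ~4 KB kernel-argument concern in the TODO above can be sanity-checked for the deepest case (depth 5), assuming 64-bit pointers and 4-byte ints:

```python
# Approximate size of TensorListMetadata<5>, which is passed by value as a kernel argument.
ptr, i32, u8 = 8, 4, 1
addresses       = 5 * 30 * ptr   # void* addresses[5][30]  -> 1200 bytes
sizes           = 30 * i32       # int sizes[30]           ->  120 bytes
block_to_tensor = 320 * u8       # unsigned char [320]     ->  320 bytes
block_to_chunk  = 320 * i32      # int [320]               -> 1280 bytes
total = addresses + sizes + block_to_tensor + block_to_chunk + i32  # + start_tensor_this_launch
print(total)  # 2924 bytes, comfortably under a ~4 KB budget
```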
template<typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(
int chunk_size,
volatile int* noop_flag,
T tl,
U callable,
ArgTypes... args)
{
// Hand the chunk information to the user-supplied functor to process however it likes.
callable(chunk_size, noop_flag, tl, args...);
}
template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
int block_size,
int chunk_size,
const at::Tensor& noop_flag,
const std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
ArgTypes... args)
{
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
for(int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
{
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for(int t = 0; t < tensor_lists[l].size(); t++)
{
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
}
}
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth> tl;
auto stream = at::cuda::getCurrentCUDAStream();
tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
for(int t = 0; t < ntensors; t++)
{
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for(int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;
int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
for(int chunk = 0; chunk < chunks_this_tensor; chunk++)
{
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth-1] &&
chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
if(tensors_full || blocks_full || last_chunk)
{
// using accscalar_t = acc_type<scalar_t, true>;
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size,
noop_flag.DATA_PTR<int>(),
tl,
callable,
args...);
AT_CUDA_CHECK(cudaGetLastError());
// Reset. The control flow possibilities here make my brain hurt.
loc_block_info = 0;
if(chunk == chunks_this_tensor - 1)
{
// std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
loc_tensor_info = 0;
tl.start_tensor_this_launch = t + 1;
}
else
{
// std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
tl.sizes[0] = tl.sizes[loc_tensor_info-1];
for(int d = 0; d < depth; d++)
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
loc_tensor_info = 1;
tl.start_tensor_this_launch = t;
}
}
}
}
}
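A hedged Python sketch of the launch bookkeeping above, i.e. how tensors are cut into fixed-size chunks and flushed into a kernel launch whenever either metadata table fills up; behaviour is simplified and the limits shown are the depth-5 constants:

```python
def plan_launches(tensor_sizes, chunk_size, max_tensors=30, max_blocks=320):
    """Yield one list of (tensor_index, chunk_index) pairs per kernel launch."""
    blocks, tensors_in_launch = [], 0
    for t, numel in enumerate(tensor_sizes):
        tensors_in_launch += 1
        chunks = (numel + chunk_size - 1) // chunk_size
        for chunk in range(chunks):
            blocks.append((t, chunk))
            tensors_full = tensors_in_launch == max_tensors and chunk == chunks - 1
            blocks_full = len(blocks) == max_blocks
            last_chunk = t == len(tensor_sizes) - 1 and chunk == chunks - 1
            if tensors_full or blocks_full or last_chunk:
                yield blocks  # one multi_tensor_apply_kernel launch
                blocks = []
                # a partially processed tensor stays in the next launch's metadata
                tensors_in_launch = 0 if chunk == chunks - 1 else 1
```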
#include <ATen/ATen.h>
#include "compat.h"
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = x[tid] + x[tid+i];
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = x[tid] + x[tid+32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i]));
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid+32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
import torch
if TYPE_CHECKING:
from torch.optim.optimizer import _params_t
else:
_params_t = Any
try:
from fairscale import fused_adam_cuda # type: ignore
class Adam(torch.optim.Optimizer):
state: dict
"""
Implements Adam algorithm. Currently GPU-only.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Compared to the original version in Apex, this version casts grads
and params to FP32 internally to support memory-efficient FP16 training.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
.. _Adam: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(
self,
params: _params_t,
lr: Optional[float] = 1e-3,
bias_correction: Optional[bool] = True,
betas: Optional[Tuple[float, float]] = (0.9, 0.999),
eps: Optional[float] = 1e-8,
eps_inside_sqrt: Optional[bool] = False,
weight_decay: Optional[float] = 0.0,
max_grad_norm: Optional[float] = 0.0,
amsgrad: Optional[bool] = False,
use_mt: Optional[bool] = True,
):
self._use_multi_tensor = False
if use_mt:
self._use_multi_tensor = True
self._overflow_buf = torch.cuda.IntTensor([0]) # type: ignore
if amsgrad:
raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
defaults = {
"lr": lr,
"bias_correction": bias_correction,
"betas": betas,
"eps": eps,
"weight_decay": weight_decay,
"max_grad_norm": max_grad_norm,
}
super().__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
@property
def supports_memory_efficient_fp16(self) -> bool:
return True
def step(self, closure: Optional[Callable[[], float]] = None, scale: Optional[float] = 1.0) -> Optional[float]:
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
scale (float, optional): factor to divide gradient tensor values
by before applying to weights; useful with manual loss scaling
for FP16 training. (default: 1.0)
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
bias_correction = 1 if group["bias_correction"] else 0
tensorlists: Dict[torch.device, List[List[torch.Tensor]]] = dict()
for p in group["params"]:
# note: a mixed-precision optimizer wrapper may legitimately leave
# p.grad unset (None) for some parameters; those parameters are
# skipped here
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
"FusedAdam does not support sparse gradients, " "please consider SparseAdam instead"
)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32)
exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
out_p = torch.tensor([])
if self._use_multi_tensor:
pl = [p.data, exp_avg, exp_avg_sq, grad]
if p.device not in tensorlists:
tensorlists[p.device] = [[], [], [], []]
for tl, t in zip(tensorlists[p.device], pl):
tl.append(t)
else:
with torch.cuda.device(p.device):
fused_adam_cuda.adam(
p.data,
out_p,
exp_avg,
exp_avg_sq,
grad,
group["lr"],
beta1,
beta2,
group["eps"],
scale,
state["step"],
self.eps_mode,
bias_correction,
group["weight_decay"],
)
if self._use_multi_tensor:
for tensordevice, tensorlist in tensorlists.items():
with torch.cuda.device(tensordevice):
fused_adam_cuda.adam_mt(
2048 * 32,
self._overflow_buf,
tensorlist,
group["lr"],
beta1,
beta2,
group["eps"],
scale,
state["step"],
self.eps_mode,
bias_correction,
group["weight_decay"],
)
return loss
except ImportError:
pass
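A hedged sketch of the two dispatch paths exposed by the Python wrapper: the default multi-tensor path (one fused_adam_cuda.adam_mt launch per device) versus use_mt=False (one fused_adam_cuda.adam call per parameter); the parameter tensor below is a placeholder:

```python
import torch
from fairscale.optim.adam import Adam

params = [torch.randn(64, 64, device="cuda", requires_grad=True)]  # placeholder parameters

opt_mt = Adam(params, lr=1e-3)                    # default: multi-tensor kernel (adam_mt)
opt_single = Adam(params, lr=1e-3, use_mt=False)  # per-parameter fused kernel (adam)
```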
......@@ -2,7 +2,12 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import os
import warnings
import setuptools
import torch
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
def fetch_requirements():
......@@ -11,6 +16,29 @@ def fetch_requirements():
return reqs
extensions = []
cmdclass = {}
force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
if (torch.cuda.is_available() and CUDA_HOME is not None) or force_cuda:
extensions.extend(
[
CUDAExtension(
name="fairscale.fused_adam_cuda",
sources=[
"fairscale/clib/fused_adam_cuda/fused_adam_cuda.cpp",
"fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu",
],
extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"]},
)
]
)
cmdclass["build_ext"] = BuildExtension
else:
warnings.warn("Cannot build the FusedAdam CUDA extension; skipping it.")
if __name__ == "__main__":
setuptools.setup(
name="fairscale",
......@@ -18,6 +46,8 @@ if __name__ == "__main__":
install_requires=fetch_requirements(),
include_package_data=True,
packages=setuptools.find_packages(exclude=("tests", "tests.*")),
ext_modules=extensions,
cmdclass=cmdclass,
python_requires=">=3.6",
author="Facebook AI Research",
author_email="todo@fb.com",
......
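A hedged way to check whether the optional extension was built after installation; FORCE_CUDA=1 is the escape hatch added above for building when no GPU is visible at build time:

```python
# After `python setup.py build develop` (optionally with FORCE_CUDA=1 in the environment):
try:
    from fairscale import fused_adam_cuda  # noqa: F401
    print("fused_adam_cuda extension is available")
except ImportError:
    print("fused_adam_cuda was not built; fairscale.optim.adam.Adam will not be importable")
```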
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from copy import deepcopy
import functools
import pytest
import torch
try:
from fairscale.optim.adam import Adam
imported_adam = True
except ImportError:
imported_adam = False
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
skip_if_no_adam = pytest.mark.skipif(not imported_adam, reason="Fairscale Adam not available")
@skip_if_no_cuda
@skip_if_no_adam
def test_step():
weight = torch.randn(10, 5).cuda().requires_grad_()
bias = torch.randn(10).cuda().requires_grad_()
input = torch.randn(5).cuda()
optimizer = Adam([weight, bias], lr=1e-3)
# to check if the optimizer can be printed as a string
optimizer.__repr__()
def fn():
optimizer.zero_grad()
y = weight.mv(input)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().item()
for _i in range(5):
optimizer.step(fn)
assert fn().item() < initial_value
@skip_if_no_cuda
@skip_if_no_adam
def test_step_multigpu():
if not torch.cuda.device_count() > 1:
return
weight = torch.randn(10, 5).cuda(0).requires_grad_()
bias = torch.randn(10).cuda(1).requires_grad_()
input = torch.randn(5).cuda(0)
optimizer = Adam([weight, bias], lr=1e-3)
# to check if the optimizer can be printed as a string
optimizer.__repr__()
def fn():
optimizer.zero_grad()
y = weight.mv(input)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().item()
for _i in range(5):
optimizer.step(fn)
assert fn().item() < initial_value
@skip_if_no_cuda
@skip_if_no_adam
def test_state_dict():
weight = torch.randn(10, 5).float().cuda().requires_grad_()
bias = torch.randn(10).float().cuda().requires_grad_()
input = torch.randn(5).float().cuda()
optimizer = Adam([weight, bias], lr=1e-3)
def fn_base(optimizer, weight, bias, input):
optimizer.zero_grad()
loss = (weight.mv(input) + bias).pow(2).sum()
loss.backward()
return loss
fn = functools.partial(fn_base, optimizer, weight, bias, input)
# Prime the optimizer
for _i in range(5):
optimizer.step(fn)
# Clone the weights and construct new optimizer for them
weight_c = weight.data.clone().requires_grad_()
bias_c = bias.data.clone().requires_grad_()
optimizer_c = Adam([weight_c, bias_c], lr=1e-3)
fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c, input)
# Load state dict
state_dict = deepcopy(optimizer.state_dict())
state_dict_c = deepcopy(optimizer.state_dict())
optimizer_c.load_state_dict(state_dict_c)
# Run both optimizations in parallel
for _i in range(5):
optimizer.step(fn)
optimizer_c.step(fn_c)
assert torch.equal(weight, weight_c)
assert torch.equal(bias, bias_c)
@skip_if_no_cuda
@skip_if_no_adam
def test_invalid_beta():
weight = torch.randn(10, 5, requires_grad=True).float().cuda()
bias = torch.randn(10, requires_grad=True).float().cuda()
with pytest.raises(ValueError):
Adam([weight, bias], lr=1e-2, betas=(1.0, 0.0))
@skip_if_no_cuda
@skip_if_no_adam
def test_invalid_weight_decay():
weight = torch.randn(10, 5, requires_grad=True).float().cuda()
bias = torch.randn(10, requires_grad=True).float().cuda()
with pytest.raises(ValueError):
Adam([weight, bias], lr=1e-2, weight_decay=-1)
@skip_if_no_cuda
@skip_if_no_adam
def test_amsgrad():
weight = torch.randn(10, 5, requires_grad=True).float().cuda()
bias = torch.randn(10, requires_grad=True).float().cuda()
with pytest.raises(RuntimeError):
Adam([weight, bias], lr=1e-2, amsgrad=True)