Commit d5f3875c authored by Shenggan

init commit

parent 62ed1c4a
#include <iostream>
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"
#include "softmax.cuh"
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
at::Tensor softmax(at::Tensor input, int rows, int cols) {
CHECK_INPUT(input);
at::Tensor output = at::empty_like(input);
fastfold::softmax::DirectLoad<at::BFloat16, float> load((at::BFloat16 *)input.data_ptr(),
int64_t(cols));
fastfold::softmax::DirectStore<float, at::BFloat16> store((at::BFloat16 *)output.data_ptr(),
int64_t(cols));
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::softmax::DispatchSoftmax<decltype(load), decltype(store), float>(cuda_stream, load,
store, rows, cols);
return output;
}
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor input, int rows, int cols) {
    CHECK_INPUT(d_output);
    CHECK_INPUT(input);
at::Tensor grad_input = at::empty_like(input);
fastfold::softmax::DirectLoad<at::BFloat16, float> load_d((at::BFloat16 *)d_output.data_ptr(),
int64_t(cols));
fastfold::softmax::DirectLoad<at::BFloat16, float> load((at::BFloat16 *)input.data_ptr(),
int64_t(cols));
fastfold::softmax::DirectStore<float, at::BFloat16> store((at::BFloat16 *)grad_input.data_ptr(),
int64_t(cols));
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::softmax::DispatchSoftmaxGrad<decltype(load), decltype(load_d), decltype(store),
float>(cuda_stream, load, load_d, store, rows, cols);
return grad_input;
}
at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, int rows, int cols,
float scale) {
CHECK_INPUT(input);
CHECK_INPUT(mask);
int head = input.sizes()[2];
at::Tensor output = at::empty_like(input);
    // ScaleMaskLoad(const SRC *src, const SRC *mask, int64_t row_size, int64_t head, SRC scale)
fastfold::softmax::ScaleMaskLoad<at::BFloat16, float> load((at::BFloat16 *)input.data_ptr(),
(at::BFloat16 *)mask.data_ptr(),
int64_t(cols), int64_t(head), scale);
fastfold::softmax::DirectStore<float, at::BFloat16> store((at::BFloat16 *)output.data_ptr(),
int64_t(cols));
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::softmax::DispatchSoftmax<decltype(load), decltype(store), float>(cuda_stream, load,
store, rows, cols);
return output;
}
at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
int rows, int cols, float scale) {
    CHECK_INPUT(d_output);
    CHECK_INPUT(input);
    CHECK_INPUT(mask);
int head = input.sizes()[2];
at::Tensor grad_input = at::empty_like(input);
fastfold::softmax::DirectLoad<at::BFloat16, float> load_d((at::BFloat16 *)d_output.data_ptr(),
int64_t(cols));
fastfold::softmax::DirectLoad<at::BFloat16, float> load((at::BFloat16 *)input.data_ptr(),
int64_t(cols));
    // ScaleMaskStore(DST *dst, const DST *mask, int64_t row_size, int64_t head, DST scale)
fastfold::softmax::ScaleMaskStore<float, at::BFloat16> store(
(at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), int64_t(cols),
int64_t(head), scale);
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::softmax::DispatchSoftmaxGrad<decltype(load), decltype(load_d), decltype(store),
float>(cuda_stream, load, load_d, store, rows, cols);
return grad_input;
}
at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
int rows, int cols, float scale) {
CHECK_INPUT(input);
CHECK_INPUT(mask);
CHECK_INPUT(bias);
int head = input.sizes()[2];
at::Tensor output = at::empty_like(input);
    // ScaleMaskBiasLoad(const SRC *src, const SRC *mask, const SRC *bias, int64_t row_size, int64_t head, SRC scale)
fastfold::softmax::ScaleMaskBiasLoad<at::BFloat16, float> load(
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
(at::BFloat16 *)bias.data_ptr(), int64_t(cols), int64_t(head), scale);
fastfold::softmax::DirectStore<float, at::BFloat16> store((at::BFloat16 *)output.data_ptr(),
int64_t(cols));
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::softmax::DispatchSoftmax<decltype(load), decltype(store), float>(cuda_stream, load,
store, rows, cols);
return output;
}
at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input,
at::Tensor mask, at::Tensor bias, int rows,
int cols, float scale) {
    CHECK_INPUT(d_output);
    CHECK_INPUT(input);
    CHECK_INPUT(mask);
    // bias is not consumed by the backward kernel, so it is left unchecked here.
    // CHECK_INPUT(bias);
    int head = input.sizes()[2];
at::Tensor grad_input = at::empty_like(input);
fastfold::softmax::DirectLoad<at::BFloat16, float> load_d((at::BFloat16 *)d_output.data_ptr(),
int64_t(cols));
fastfold::softmax::DirectLoad<at::BFloat16, float> load((at::BFloat16 *)input.data_ptr(),
int64_t(cols));
    // ScaleMaskStore(DST *dst, const DST *mask, int64_t row_size, int64_t head, DST scale)
fastfold::softmax::ScaleMaskStore<float, at::BFloat16> store(
(at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), int64_t(cols),
int64_t(head), scale);
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::softmax::DispatchSoftmaxGrad<decltype(load), decltype(load_d), decltype(store),
float>(cuda_stream, load, load_d, store, rows, cols);
return grad_input;
}
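The bindings above all work on a flattened (rows, cols) view of a contiguous bfloat16 CUDA tensor, with the softmax taken over the last dimension. The sketch below drives the compiled extension directly; the module and function names (fastfold_softmax_cuda, forward_affine) are the ones used by the Python wrapper later in this commit, and the shapes are illustrative assumptions.

import torch
import fastfold_softmax_cuda  # extension built from this file; name taken from the wrapper below

# The kernels only see rows = product of the leading dims and cols = the last dim.
x = torch.randn(4, 8, 64, 128, device="cuda", dtype=torch.bfloat16).contiguous()
rows = x.numel() // x.shape[-1]
cols = x.shape[-1]

y = fastfold_softmax_cuda.forward_affine(x, rows, cols)
# Sanity check against the reference softmax (bf16-level tolerance expected).
print((y.float() - torch.softmax(x.float(), dim=-1)).abs().max())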
#include <ATen/ATen.h>
#include "compat.h"
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: { \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes(T *x, T val, int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
import importlib
import numbers
import torch
from torch.nn import init
from torch.nn.parameter import Parameter
global fastfold_layer_norm_cuda
fastfold_layer_norm_cuda = None
class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fastfold_layer_norm_cuda.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= fastfold_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
class MixedFusedLayerNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5):
super(MixedFusedLayerNorm, self).__init__()
global fastfold_layer_norm_cuda
if fastfold_layer_norm_cuda is None:
try:
fastfold_layer_norm_cuda = importlib.import_module("fastfold_layer_norm_cuda")
except ImportError:
raise RuntimeError('MixedFusedLayerNorm requires cuda extensions')
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
        self.weight = Parameter(torch.empty(*normalized_shape))
        self.bias = Parameter(torch.empty(*normalized_shape))
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input):
return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
self.normalized_shape, self.eps)
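A minimal usage sketch for the module above, assuming the fastfold_layer_norm_cuda extension has been built and a CUDA device is available; the shapes are illustrative.

import torch

ln = MixedFusedLayerNorm(256).cuda()   # normalizes over the last 256 features
x = torch.randn(8, 128, 256, device="cuda", requires_grad=True)

y = ln(x)            # forward dispatches to FusedLayerNormAffineFunction
y.sum().backward()   # backward fills gradients for x, ln.weight and ln.bias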
import importlib
from functools import reduce
from operator import mul
import torch
fastfold_softmax_cuda = importlib.import_module("fastfold_softmax_cuda")
class SoftmaxAffineFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
input_ = input.contiguous()
ctx.cols = input_.shape[-1]
        ctx.rows = reduce(mul, input_.shape[:-1], 1)
output = fastfold_softmax_cuda.forward_affine(input_, ctx.rows, ctx.cols)
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_output):
output = ctx.saved_tensors[0]
grad_input = None
grad_input = fastfold_softmax_cuda.backward_affine(grad_output.contiguous(), output,
ctx.rows, ctx.cols)
return grad_input
class FusedScaleMaskSoftmaxFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask, scale):
input_ = input.contiguous()
mask_ = mask.contiguous()
ctx.cols = input_.shape[-1]
        ctx.rows = reduce(mul, input_.shape[:-1], 1)
output = fastfold_softmax_cuda.fused_scale_mask_softmax_forward(
input_, mask_, ctx.rows, ctx.cols, scale)
ctx.save_for_backward(output, mask_)
ctx.scale = scale
return output
@staticmethod
def backward(ctx, grad_output):
output, mask_ = ctx.saved_tensors
grad_input = None
grad_input = fastfold_softmax_cuda.fused_scale_mask_softmax_backward(
grad_output.contiguous(), output, mask_, ctx.rows, ctx.cols, ctx.scale)
return grad_input.contiguous(), None, None
class FusedScaleMaskBiasSoftmaxFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask, bias, scale):
input_ = input.contiguous()
mask_ = mask.contiguous()
bias_ = bias.contiguous()
ctx.cols = input_.shape[-1]
        ctx.rows = reduce(mul, input_.shape[:-1], 1)
output = fastfold_softmax_cuda.fused_scale_mask_bias_softmax_forward(
input_, mask_, bias_, ctx.rows, ctx.cols, scale)
ctx.save_for_backward(output, mask_, bias_)
ctx.scale = scale
return output
@staticmethod
def backward(ctx, grad_output):
output, mask_, bias_ = ctx.saved_tensors
grad_input = None
grad_input = fastfold_softmax_cuda.fused_scale_mask_bias_softmax_backward(
grad_output.contiguous(), output, mask_, bias_, ctx.rows, ctx.cols, ctx.scale)
        grad_input = grad_input.contiguous()
        grad_bias = torch.sum(grad_input, dim=1, keepdim=True)
        # Gradients must be returned in the order of forward's inputs:
        # (input, mask, bias, scale) -> (grad_input, None, grad_bias, None).
        return grad_input, None, grad_bias, None
softmax = SoftmaxAffineFunction.apply
scale_mask_softmax = FusedScaleMaskSoftmaxFunction.apply
scale_mask_bias_softmax = FusedScaleMaskBiasSoftmaxFunction.apply
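A usage sketch for the softmax alias above; the masked variants additionally take mask (and bias) tensors plus a scale, with broadcasting defined by the CUDA kernels in softmax.cuh, which are not shown in this excerpt. A built extension and bfloat16 CUDA inputs are assumed.

import torch

x = torch.randn(2, 4, 8, 64, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

y = softmax(x)        # fused softmax over the last dimension
y.sum().backward()    # gradient is recomputed from the saved softmax output

print((y.float() - torch.softmax(x.float(), dim=-1)).abs().max())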
from .options import _set_jit_fusion_options
_set_jit_fusion_options()
import torch
import torch.nn.functional as F
@torch.jit.script
def bias_sigmod_ele(y, bias, z):
return torch.sigmoid(y + bias) * z
@torch.jit.script
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
residual: torch.Tensor, prob: float) -> torch.Tensor:
out = (x + bias) * F.dropout(dropmask, p=prob, training=True)
out = residual + out
return out
@torch.jit.script
def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
prob: float) -> torch.Tensor:
return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))
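A sketch of how these JIT-scripted helpers might be called. Note that dropmask / dropout_mask is itself passed through F.dropout, so callers typically hand in a tensor of ones whose shape decides which positions share a dropout pattern; the shapes below are assumptions.

import torch

x        = torch.randn(1, 64, 64, 128, device="cuda")
bias     = torch.randn(128, device="cuda")
residual = torch.randn(1, 64, 64, 128, device="cuda")
ones     = torch.ones(1, 64, 1, 128, device="cuda")   # broadcastable dropout-mask source (assumption)

out = bias_dropout_add(x, bias, ones, residual, 0.1)  # (x + bias) * dropout(ones) + residual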
import torch
JIT_OPTIONS_SET = False
def _set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options."""
global JIT_OPTIONS_SET
    if not JIT_OPTIONS_SET:
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
# if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
# # nvfuser
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
# torch._C._jit_override_can_fuse_on_cpu(False)
# torch._C._jit_override_can_fuse_on_gpu(False)
# torch._C._jit_set_texpr_fuser_enabled(False)
# torch._C._jit_set_nvfuser_enabled(True)
# torch._C._debug_set_autodiff_subgraph_inlining(False)
# else:
# legacy pytorch fuser
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
JIT_OPTIONS_SET = True
einops
colossalai