Unverified Commit 5c9b21d8 authored by yjk21's avatar yjk21 Committed by GitHub
Browse files

adds fmhalib (#1074)

parent e5f2f675
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <multihead_attn/philox.h>
#include <fmha.h>
#include <fmha/utils.h>
#include <fmha/smem_tile.h>
#include <fmha/gmem_tile.h>
#include <fmha/mask.h>
#include <fmha/softmax.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int FMHA_VERSION> struct BlockInfo {};
template <> struct BlockInfo<1> {
int actual_seqlen;
int bidx;
int sum_s;
int bidh;
int bidb;
template<typename Params>
__device__ BlockInfo( const Params &params,
const int bidb,
const int bidh,
const int tidx )
: bidb( bidb ), bidh( bidh ) {
// The block index.
sum_s = params.b * params.s;
actual_seqlen = params.s;
bidx = bidb * params.h + bidh;
}
__device__ bool stop_early() const {
return false;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <> struct BlockInfo<2> {
int actual_seqlen;
int bidx;
int sum_s;
int bidh;
int bidb;
template<typename Params>
__device__ BlockInfo( const Params &params,
const int bidb,
const int bidh,
const int tidx )
: bidb( bidb ), bidh( bidh ) {
// The block index.
sum_s = params.cu_seqlens[bidb];
actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;
bidx = sum_s * params.h + bidh;
}
__device__ bool stop_early() const {
return actual_seqlen == 0;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS_PER_CTA>
struct BlockInfoPadded {
template<typename Params>
__device__ BlockInfoPadded( const Params &params,
const int bidb,
const int bidh,
const int tidx )
: bidb( bidb ), bidh( bidh ), h(params.h) {
// The block index.
sum_s = params.cu_seqlens[bidb];
actual_seqlen = params.seqlens[bidb];
bidx = sum_s * params.h + bidh;
tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
}
__device__ bool stop_early() const {
return actual_seqlen == 0;
}
int actual_seqlen;
int bidx;
int sum_s;
int bidh;
int bidb;
int tidx_global;
int h;
};
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
#define FMHA_CHECK_CUDA( call ) \
do { \
cudaError_t status_ = call; \
if( status_ != cudaSuccess ) { \
fprintf( stderr, \
"CUDA error (%s:%d): %s\n", \
__FILE__, \
__LINE__, \
cudaGetErrorString( status_ ) ); \
exit( 1 ); \
} \
} while( 0 )
////////////////////////////////////////////////////////////////////////////////////////////////////
enum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) {
if( dtype == DATA_TYPE_FP16 ) {
half x = __float2half_rn( norm );
uint16_t h = reinterpret_cast<const uint16_t &>( x );
ushort2 h2 = { h, h };
alpha = reinterpret_cast<const uint32_t &>( h2 );
} else if( dtype == DATA_TYPE_FP32 ) {
alpha = reinterpret_cast<const uint32_t &>( norm );
} else if( dtype == DATA_TYPE_INT32 ) {
int32_t inorm = static_cast<int32_t>( norm );
alpha = reinterpret_cast<const uint32_t &>( inorm );
} else {
assert( false );
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline size_t get_size_in_bytes( size_t n, Data_type dtype ) {
switch( dtype ) {
case DATA_TYPE_FP32:
return n * 4;
case DATA_TYPE_FP16:
return n * 2;
case DATA_TYPE_INT32:
return n * 4;
case DATA_TYPE_INT8:
return n;
default:
assert( false );
return 0;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
from .fmha import FMHAFun
###############################################################################
# Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
###############################################################################
import torch
import torch.nn.functional as F
import fmhalib as mha
class FMHAFun(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, seqlens, p_dropout, max_s, is_training):
context, S_dmask = mha.fwd(qkv, cu_seqlens, seqlens, p_dropout, max_s, is_training, None)
ctx.save_for_backward(qkv, S_dmask)
ctx.cu_seqlens = cu_seqlens
ctx.seqlens = seqlens
ctx.p_dropout = p_dropout
ctx.max_s = max_s
return context
@staticmethod
def backward(ctx, dout):
qkv, S_dmask = ctx.saved_tensors
dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.seqlens, ctx.p_dropout, ctx.max_s)
return dqkv, None, None, None, None, None, None
class FMHA(torch.nn.Module):
def __init__(self, config):
super(FMHA, self).__init__()
self.p_dropout = config.attention_probs_dropout_prob
self.h = config.num_attention_heads
self.hidden_size = config.hidden_size
self.d = self.hidden_size // self.h
assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads"
def forward(self, qkv, cu_seqlens, seqlens, max_s, is_training=True):
ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, seqlens, self.p_dropout, max_s, is_training)
return ctx.view(-1, self.hidden_size)
###############################################################################
# Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
###############################################################################
import sys
import torch
import numpy as np
import unittest
import math
import fmhalib as mha
def py_mha(qkv, amask, b, s, h, d):
qkv = qkv.view(b, s, h, 3, d)
q = qkv[:, :, :, 0, :].permute(0,2,1,3)
k = qkv[:, :, :, 1, :].permute(0,2,1,3)
v = qkv[:, :, :, 2, :].permute(0,2,1,3)
p = torch.matmul(q.float(), k.permute(0,1,3,2).float())
p_masked = p / math.sqrt(d) + (1.0 - amask) * -10000.0
s = torch.softmax(p_masked, -1).to(qkv.dtype)
ctx = torch.matmul(s, v)
ctx = ctx.permute(0,2,1,3).contiguous()
ctx.retain_grad()
return ctx
class TestFMHA(unittest.TestCase):
def run_test(self, s):
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
dtype = torch.float16
device = torch.device('cuda')
b = 32
h = 16
d = 64
slens = [s] * b
a = torch.tensor(np.array([0] + slens), dtype=torch.int32)
amask = torch.ones(b,h,s,s, dtype=dtype, device=device)
seqlens = torch.tensor(slens, dtype=torch.int32, device=device)
cu_seqlens = torch.cumsum(a, 0).to(dtype=torch.int32, device=device)
total = cu_seqlens[-1].item()
qkv = torch.randn((b,s,h,3,d), device=device, dtype=dtype)
qkv_vs = qkv.permute(0,1,3,2,4).contiguous().view(b*s, 3, h,d)
qkv.requires_grad = True
ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, seqlens, 0.0, s, True, None)
ctx = ctx.view(b,s,h,d)
ctx_ref = py_mha(qkv, amask, b,s,h,d)
self.assertTrue(torch.allclose(ctx_ref.float(), ctx.float(), atol=1e-3))
labels = torch.randn_like(ctx_ref)
diff = ctx_ref - labels
l = (diff * diff).sum() / b
l.backward()
dw = ctx_ref.grad.permute(0,2,1,3)
dw2 = dw.permute(0,2,1,3).clone().detach().contiguous()
dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, seqlens, 0.0, s)
dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d)
self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))
def test_128(self):
self.run_test(128)
def test_256(self):
self.run_test(256)
def test_384(self):
self.run_test(384)
def test_512(self):
self.run_test(512)
if __name__ == '__main__':
unittest.main()
......@@ -329,6 +329,48 @@ if "--fast_layer_norm" in sys.argv:
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
if "--fmha" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--fmha")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fmha was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) < 11:
raise RuntimeError("--fmha only supported on SM80")
ext_modules.append(
CUDAExtension(name='fmhalib',
sources=[
'apex/contrib/csrc/fmha/fmha_api.cpp',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu',
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu',
],
extra_compile_args={'cxx': ['-O3',
'-I./apex/contrib/csrc/fmha/src',
] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_80,code=sm_80',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'-I./apex/contrib/csrc/',
'-I./apex/contrib/csrc/fmha/src',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
if "--fast_multihead_attn" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment