Unverified Commit 7e1c22d0 authored by chochowski, committed by GitHub

contrib/fmha: Add option to zero out tensors before math (#1322)



* extend api to allow forced memory zeroing (empty() does not do it)

* typo fix

* ctx change

* move zeroing flag to ctx

* update test
Co-authored-by: mchochowski <mchochowski@nvidia.com>
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
parent 44c30436
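For context: `torch.empty()` / `torch::empty()` only allocates storage and leaves its contents uninitialized, so buffers such as `ctx`, `s`, and `dqkv` below can hold stale values from earlier allocations. The new `zero_tensors` flag explicitly zeroes them before the kernels run. A minimal sketch of the underlying behavior in plain PyTorch (illustrative only, not the fmha extension itself):

```python
import torch

# torch.empty allocates memory without initializing it, so the contents are
# whatever happened to be in that allocation ("garbage" values).
buf = torch.empty(4, 4)
print(buf)  # arbitrary, possibly non-zero values

# Explicitly zeroing gives deterministic contents; this is what the new
# zero_tensors flag requests inside the extension before the math runs.
buf.zero_()
print(buf)  # all zeros

# Equivalent one-step allocation:
buf2 = torch.zeros(4, 4)
```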
@@ -89,6 +89,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
     const float p_dropout,
     const int max_seq_len,
     const bool is_training,
+    const bool zero_tensors,
     c10::optional<at::Generator> gen_) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
@@ -147,6 +148,11 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
     auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
+    if( zero_tensors ) {
+        ctx.zero_();
+        s.zero_();
+    }
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -185,7 +191,8 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
     at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
     const at::Tensor &cu_seqlens, // b+1
     const float p_dropout, // probability to drop
-    const int max_seq_len // max sequence length to choose the kernel
+    const int max_seq_len, // max sequence length to choose the kernel
+    const bool zero_tensors
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
@@ -235,6 +242,10 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
     auto dqkv = torch::empty_like(qkv);
+    if( zero_tensors ) {
+        dqkv.zero_();
+    }
     Fused_multihead_attention_fprop_params params;
     set_params(params,
@@ -264,6 +275,7 @@ std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num
     const float p_dropout,
     const int max_seq_len,
     const bool is_training,
+    const bool zero_tensors,
     c10::optional<at::Generator> gen_) {
     int seq_len = 512;
     auto launch = &run_fmha_fp16_512_64_sm80_nl;
@@ -304,6 +316,11 @@ std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num
     auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
+    if( zero_tensors ) {
+        ctx.zero_();
+        s.zero_();
+    }
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
     Fused_multihead_attention_fprop_params params;
@@ -344,7 +361,8 @@ std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num
     at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
     const at::Tensor &cu_seqlens, // b+1
     const float p_dropout, // probability to drop
-    const int max_seq_len // max sequence length to choose the kernel
+    const int max_seq_len, // max sequence length to choose the kernel
+    const bool zero_tensors
 ) {
     auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -378,6 +396,10 @@ std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num
     auto dqkv = torch::empty_like(qkv);
+    if( zero_tensors ) {
+        dqkv.zero_();
+    }
     int num_chunks = 2;
     if( batch_size == 1 ) {
         num_chunks = 4;
@@ -32,16 +32,17 @@ import fmhalib as mha
 class FMHAFun(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):
+    def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors):
         batch_size = cu_seqlens.numel() - 1
         if batch_size < 4:
-            context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
+            context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors, None)
         else:
-            context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
+            context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors, None)
         ctx.save_for_backward(qkv, S_dmask)
         ctx.cu_seqlens = cu_seqlens
         ctx.p_dropout = p_dropout
         ctx.max_s = max_s
+        ctx.zero_tensors = zero_tensors
         return context

     @staticmethod
@@ -49,11 +50,11 @@ class FMHAFun(torch.autograd.Function):
         qkv, S_dmask = ctx.saved_tensors
         batch_size = ctx.cu_seqlens.numel() - 1
         if batch_size < 4:
-            dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
+            dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
         else:
-            dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
+            dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
-        return dqkv, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None
class FMHA(torch.nn.Module):
@@ -67,8 +68,8 @@ class FMHA(torch.nn.Module):
         self.d = self.hidden_size // self.h
         assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads"

-    def forward(self, qkv, cu_seqlens, max_s, is_training=True):
+    def forward(self, qkv, cu_seqlens, max_s, is_training=True, zero_tensors=False):
-        ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training)
+        ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training, zero_tensors)
         return ctx.view(-1, self.hidden_size)
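A note on the `FMHAFun` change above: `torch.autograd.Function.backward` must return one gradient (or `None`) per argument of `forward`, so adding `zero_tensors` to `forward` also appends one more `None` to the tuple returned by `backward`. A small self-contained sketch of that pattern with generic names (not the fmha code):

```python
import torch

class ScaleFun(torch.autograd.Function):
    """Toy function: scales the input and optionally zeroes a work buffer."""

    @staticmethod
    def forward(ctx, x, scale, zero_tensors):
        buf = torch.empty_like(x)   # uninitialized work buffer
        if zero_tensors:
            buf.zero_()             # force deterministic contents, like the new flag
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # One entry per forward input (x, scale, zero_tensors):
        # only x gets a gradient, the non-tensor arguments get None.
        return grad_out * ctx.scale, None, None

x = torch.randn(3, requires_grad=True)
y = ScaleFun.apply(x, 2.0, True)
y.sum().backward()
print(x.grad)  # tensor filled with 2.0
```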
@@ -51,8 +51,8 @@ def py_mha(qkv, amask, b, s, h, d):
 class TestFMHA(unittest.TestCase):
-    def run_test(self, s, b):
-        print(f'Test s={s} b={b}')
+    def run_test(self, s: int, b: int, zero_tensors: bool):
+        print(f'Test s={s} b={b}, zero_tensors={zero_tensors}')
         torch.manual_seed(1234)
         torch.cuda.manual_seed(1234)
@@ -77,9 +77,9 @@ class TestFMHA(unittest.TestCase):
         qkv.requires_grad = True
         if b < 4:
-            ctx, S_ = mha.fwd_nl(qkv_vs, cu_seqlens, 0.0, s, True, None)
+            ctx, S_ = mha.fwd_nl(qkv_vs, cu_seqlens, 0.0, s, True, zero_tensors, None)
         else:
-            ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, None)
+            ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, zero_tensors, None)
         ctx = ctx.view(b,s,h,d)
         ctx_ref = py_mha(qkv, amask, b,s,h,d)
@@ -95,27 +95,34 @@ class TestFMHA(unittest.TestCase):
         dw2 = dw.permute(0,2,1,3).clone().detach().contiguous()
         if b < 4:
-            dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
+            dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors)
         else:
-            dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
+            dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors)
         dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d)
         self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))

     def test_128(self):
-        self.run_test(128, 32)
+        self.run_test(128, 32, False)
+        self.run_test(128, 32, True)

     def test_256(self):
-        self.run_test(256, 32)
+        self.run_test(256, 32, False)
+        self.run_test(256, 32, True)

     def test_384(self):
-        self.run_test(384, 32)
+        self.run_test(384, 32, False)
+        self.run_test(384, 32, True)

     def test_512(self):
-        self.run_test(512, 32)
-        self.run_test(512, 2)
-        self.run_test(512, 3)
+        self.run_test(512, 32, False)
+        self.run_test(512, 32, True)
+        self.run_test(512, 2, False)
+        self.run_test(512, 2, True)
+        self.run_test(512, 3, False)
+        self.run_test(512, 3, True)

 if __name__ == '__main__':
     unittest.main()