Unverified Commit 7e1c22d0 authored by chochowski, committed by GitHub

contrib/fmha: Add option to zero out tensors before math (#1322)



* extend API to allow forced memory zeroing (torch.empty() does not zero memory); see the sketch below

* typo fix

* ctx change

* move zeroing flag to ctx

* update test
Co-authored-by: mchochowski <mchochowski@nvidia.com>
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
parent 44c30436
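Background for the change (a minimal sketch, not part of the diff): torch.empty() only allocates storage and leaves its contents uninitialized, so buffers created with it hold arbitrary stale values until a kernel writes every element. The new zero_tensors flag forces an explicit .zero_() on those buffers. The snippet below only illustrates the empty()-vs-zero_() difference; it is not code from this commit.

import torch

# torch.empty() reserves memory without initializing it: the values below are
# whatever bytes happened to be in that allocation, so reading them before
# every element is written is unspecified.
buf = torch.empty(2, 3)
print(buf)       # arbitrary leftover values

# An explicit .zero_() (what zero_tensors triggers for the ctx / s / dqkv
# buffers in this commit) guarantees a deterministic, all-zero starting point.
buf.zero_()
print(buf)       # tensor of zeros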
@@ -89,6 +89,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
         const float p_dropout,
         const int max_seq_len,
         const bool is_training,
+        const bool zero_tensors,
         c10::optional<at::Generator> gen_) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
@@ -147,6 +148,11 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
     auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);

+    if( zero_tensors ) {
+        ctx.zero_();
+        s.zero_();
+    }
+
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -185,7 +191,8 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
         at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
         const at::Tensor &cu_seqlens, // b+1
         const float p_dropout, // probability to drop
-        const int max_seq_len // max sequence length to choose the kernel
+        const int max_seq_len, // max sequence length to choose the kernel
+        const bool zero_tensors
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
@@ -235,6 +242,10 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
     auto dqkv = torch::empty_like(qkv);

+    if( zero_tensors ) {
+        dqkv.zero_();
+    }
+
     Fused_multihead_attention_fprop_params params;

     set_params(params,
@@ -264,6 +275,7 @@ std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num
         const float p_dropout,
         const int max_seq_len,
         const bool is_training,
+        const bool zero_tensors,
         c10::optional<at::Generator> gen_) {
     int seq_len = 512;
     auto launch = &run_fmha_fp16_512_64_sm80_nl;
@@ -304,6 +316,11 @@ std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num
     auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);

+    if( zero_tensors ) {
+        ctx.zero_();
+        s.zero_();
+    }
+
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());

     Fused_multihead_attention_fprop_params params;
@@ -344,7 +361,8 @@ std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num
         at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
         const at::Tensor &cu_seqlens, // b+1
         const float p_dropout, // probability to drop
-        const int max_seq_len // max sequence length to choose the kernel
+        const int max_seq_len, // max sequence length to choose the kernel
+        const bool zero_tensors
 ) {
     auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -378,6 +396,10 @@ std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num
     auto dqkv = torch::empty_like(qkv);

+    if( zero_tensors ) {
+        dqkv.zero_();
+    }
+
     int num_chunks = 2;
     if( batch_size == 1 ) {
         num_chunks = 4;
...
@@ -32,16 +32,17 @@ import fmhalib as mha

 class FMHAFun(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):
+    def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors):
         batch_size = cu_seqlens.numel() - 1
         if batch_size < 4:
-            context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
+            context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors, None)
         else:
-            context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
+            context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors, None)
         ctx.save_for_backward(qkv, S_dmask)
         ctx.cu_seqlens = cu_seqlens
         ctx.p_dropout = p_dropout
         ctx.max_s = max_s
+        ctx.zero_tensors = zero_tensors
         return context

     @staticmethod
@@ -49,11 +50,11 @@ class FMHAFun(torch.autograd.Function):
         qkv, S_dmask = ctx.saved_tensors
         batch_size = ctx.cu_seqlens.numel() - 1
         if batch_size < 4:
-            dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
+            dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
         else:
-            dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
+            dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors)

-        return dqkv, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None

 class FMHA(torch.nn.Module):
@@ -67,8 +68,8 @@ class FMHA(torch.nn.Module):
         self.d = self.hidden_size // self.h
         assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads"

-    def forward(self, qkv, cu_seqlens, max_s, is_training=True):
-        ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training)
+    def forward(self, qkv, cu_seqlens, max_s, is_training=True, zero_tensors=False):
+        ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training, zero_tensors)

         return ctx.view(-1, self.hidden_size)
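For reference, a minimal usage sketch of the extended Python API (not part of the diff). It assumes the fmhalib extension is built for an sm80 GPU, that every sequence has the same length s, and that qkv uses the (total, 3, heads, head_dim) layout implied by the view in FMHA.forward above; the sizes, dtype, and dropout value are illustrative assumptions.

import torch
import fmhalib as mha  # compiled FMHA extension; requires an sm80 (A100) GPU

b, s, h, d = 8, 512, 16, 64  # illustrative batch, seq length, heads, head dim
qkv = torch.randn(b * s, 3, h, d, dtype=torch.float16, device='cuda')
cu_seqlens = torch.arange(0, (b + 1) * s, step=s, dtype=torch.int32, device='cuda')

# New in this commit: the bool before the generator argument is zero_tensors.
# Passing True zeroes the output and softmax buffers with .zero_() before the
# kernel runs instead of leaving them as uninitialized torch.empty() storage.
ctx, S_dmask = mha.fwd(qkv, cu_seqlens, 0.0, s, True, True, None)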
@@ -51,8 +51,8 @@ def py_mha(qkv, amask, b, s, h, d):

 class TestFMHA(unittest.TestCase):

-    def run_test(self, s, b):
-        print(f'Test s={s} b={b}')
+    def run_test(self, s: int, b: int, zero_tensors: bool):
+        print(f'Test s={s} b={b}, zero_tensors={zero_tensors}')

         torch.manual_seed(1234)
         torch.cuda.manual_seed(1234)
@@ -77,9 +77,9 @@ class TestFMHA(unittest.TestCase):
         qkv.requires_grad = True

         if b < 4:
-            ctx, S_ = mha.fwd_nl(qkv_vs, cu_seqlens, 0.0, s, True, None)
+            ctx, S_ = mha.fwd_nl(qkv_vs, cu_seqlens, 0.0, s, True, zero_tensors, None)
         else:
-            ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, None)
+            ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, zero_tensors, None)
         ctx = ctx.view(b,s,h,d)

         ctx_ref = py_mha(qkv, amask, b,s,h,d)
@@ -95,27 +95,34 @@ class TestFMHA(unittest.TestCase):
         dw2 = dw.permute(0,2,1,3).clone().detach().contiguous()

         if b < 4:
-            dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
+            dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors)
         else:
-            dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
+            dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors)

         dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d)

         self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))

     def test_128(self):
-        self.run_test(128, 32)
+        self.run_test(128, 32, False)
+        self.run_test(128, 32, True)

     def test_256(self):
-        self.run_test(256, 32)
+        self.run_test(256, 32, False)
+        self.run_test(256, 32, True)

     def test_384(self):
-        self.run_test(384, 32)
+        self.run_test(384, 32, False)
+        self.run_test(384, 32, True)

     def test_512(self):
-        self.run_test(512, 32)
-        self.run_test(512, 2)
-        self.run_test(512, 3)
+        self.run_test(512, 32, False)
+        self.run_test(512, 32, True)
+        self.run_test(512, 2, False)
+        self.run_test(512, 2, True)
+        self.run_test(512, 3, False)
+        self.run_test(512, 3, True)

 if __name__ == '__main__':
     unittest.main()