Unverified commit 1403c21a, authored by Kevin Stephano and committed by GitHub

Remove legacy fuser usage from multihead attention in contrib in favor of the default which should be nvfuser.  Modify test scripts to activate fusion. (#1403)
parent 5ffb22d0
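Note: with the explicit profiling-executor overrides removed (see the hunks below), the scripted helpers in these modules run under PyTorch's default profiling graph executor, which can hand fusible regions to nvfuser on CUDA. As a hedged sketch (not part of this commit), a specific fuser can also be selected explicitly through the torch.jit.fuser context manager, where "fuser0" is the legacy fuser, "fuser1" is NNC, and "fuser2" is nvfuser:

import torch

def run_with_nvfuser(scripted_fn, *args):
    # Select nvfuser ("fuser2") only for this call; outside the context the
    # default fuser selection is left untouched, which is what this commit relies on.
    with torch.jit.fuser("fuser2"):
        return scripted_fn(*args)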
@@ -10,12 +10,6 @@ from .fast_encdec_multihead_attn_func import fast_encdec_attn_func
 from .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func
 from apex.normalization.fused_layer_norm import FusedLayerNorm
-
-if hasattr(torch._C, "_jit_set_profiling_executor"):
-    torch._C._jit_set_profiling_executor(False)
-if hasattr(torch._C, "_jit_set_profiling_mode"):
-    torch._C._jit_set_profiling_mode(False)
-
 
 @torch.jit.script
 def jit_dropout_add(x, residual, prob, is_training):
     # type: (Tensor, Tensor, float, bool) -> Tensor
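The body of jit_dropout_add is outside this hunk; a minimal sketch of such a scripted dropout-plus-residual helper might look like the following (the body is illustrative, not necessarily apex's exact implementation):

import torch
import torch.nn.functional as F

@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
    # type: (Tensor, Tensor, float, bool) -> Tensor
    # Dropout on the attention output followed by the residual add; the JIT can
    # fuse this elementwise sequence once the profiling executor has specialized it.
    out = F.dropout(x, p=prob, training=is_training)
    out = residual + out
    return out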
@@ -10,12 +10,6 @@ from .fast_self_multihead_attn_func import fast_self_attn_func
 from .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func
 from apex.normalization.fused_layer_norm import FusedLayerNorm
-
-if hasattr(torch._C, "_jit_set_profiling_executor"):
-    torch._C._jit_set_profiling_executor(False)
-if hasattr(torch._C, "_jit_set_profiling_mode"):
-    torch._C._jit_set_profiling_mode(False)
-
 
 @torch.jit.script
 def jit_dropout_add(x, residual, prob, is_training):
     # type: (Tensor, Tensor, float, bool) -> Tensor
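Since the modules now rely on whatever fuser the default executor picks, one way to confirm that nvfuser actually engaged is to look for its fusion node in the last optimized graph. A sketch using torch.jit.last_executed_optimized_graph, an introspection helper; prim::CudaFusionGroup is the node kind nvfuser emits:

import torch

def saw_cuda_fusion(scripted_fn, *args, warmup: int = 5) -> bool:
    # Run a few times so the profiling executor can specialize and fuse, then
    # scan the most recently executed optimized graph for nvfuser's fusion node.
    for _ in range(warmup):
        scripted_fn(*args)
    graph = torch.jit.last_executed_optimized_graph()
    # Fusion groups can also sit inside nested blocks; checking the top level
    # is usually enough for a small scripted function like jit_dropout_add.
    return any(node.kind() == "prim::CudaFusionGroup" for node in graph.nodes())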
@@ -48,25 +48,26 @@ class EncdecMultiheadAttnNormAddTest(unittest.TestCase):
    def test_encdec_multihead_attn_norm_add(self) :
        grads = torch.randn_like(self.tst_inputs_q)
        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
                                               self.ref_inputs_k,
                                               self.ref_inputs_k,
                                               key_padding_mask=None,
                                               need_weights=False,
                                               attn_mask=None,
                                               is_training=True)
        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q,
                                               self.tst_inputs_k,
                                               self.tst_inputs_k,
                                               key_padding_mask=None,
                                               need_weights=False,
                                               attn_mask=None,
                                               is_training=True)
        self.ref_inputs_q.backward(grads)
        self.tst_inputs_q.backward(grads)
        for _ in range(5) :
            ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
                                                   self.ref_inputs_k,
                                                   self.ref_inputs_k,
                                                   key_padding_mask=None,
                                                   need_weights=False,
                                                   attn_mask=None,
                                                   is_training=True)
            tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q,
                                                   self.tst_inputs_k,
                                                   self.tst_inputs_k,
                                                   key_padding_mask=None,
                                                   need_weights=False,
                                                   attn_mask=None,
                                                   is_training=True)
            self.ref_inputs_q.backward(grads)
            self.tst_inputs_q.backward(grads)
        self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
        self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
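The added loop is what "activate fusion" in the commit message refers to: under the profiling executor, the first call(s) to a scripted function only record shapes, and the fused kernel is compiled and used on later iterations, so a single forward/backward pass would never exercise the nvfuser path. A minimal standalone illustration of the same warm-up pattern (assumes a CUDA device; dropout_residual is a stand-in, not an apex function):

import torch
import torch.nn.functional as F

@torch.jit.script
def dropout_residual(x, residual, prob: float, is_training: bool):
    out = F.dropout(x, p=prob, training=is_training)
    return residual + out

x = torch.randn(128, 1024, device="cuda", requires_grad=True)
residual = torch.randn(128, 1024, device="cuda")
grads = torch.randn(128, 1024, device="cuda")

# Early iterations are profiling runs; only after specialization does the JIT
# hand the elementwise region to the fuser, so loop a few times like the tests do.
for _ in range(5):
    out = dropout_residual(x, residual, 0.1, True)
    out.backward(grads)
    x.grad = None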
@@ -45,24 +45,25 @@ class SelfMultiheadAttnNormAddTest(unittest.TestCase):
    def test_self_multihead_attn_norm_add(self) :
        grads = torch.randn_like(self.tst_inputs)
        ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
                                               self.ref_inputs,
                                               self.ref_inputs,
                                               key_padding_mask=None,
                                               need_weights=False,
                                               attn_mask=None,
                                               is_training=True)
        tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
                                               self.tst_inputs,
                                               self.tst_inputs,
                                               key_padding_mask=None,
                                               need_weights=False,
                                               attn_mask=None,
                                               is_training=True)
        self.ref_inputs.backward(grads)
        self.tst_inputs.backward(grads)
        for _ in range(0, 5) :
            ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
                                                   self.ref_inputs,
                                                   self.ref_inputs,
                                                   key_padding_mask=None,
                                                   need_weights=False,
                                                   attn_mask=None,
                                                   is_training=True)
            tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
                                                   self.tst_inputs,
                                                   self.tst_inputs,
                                                   key_padding_mask=None,
                                                   need_weights=False,
                                                   attn_mask=None,
                                                   is_training=True)
            self.ref_inputs.backward(grads)
            self.tst_inputs.backward(grads)
        self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
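Five iterations is comfortably above the profiling executor's default number of profiling runs before it optimizes. If a test ever needs a different threshold, PyTorch exposes an internal, undocumented knob for it; a hedged sketch of saving and restoring that setting:

import torch

# Internal API, may change between PyTorch releases; the setter returns the
# previous value so it can be restored afterwards.
old_runs = torch._C._jit_set_num_profiled_runs(2)
try:
    pass  # run the scripted workload here; fusion should engage after 2 runs
finally:
    torch._C._jit_set_num_profiled_runs(old_runs)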