Commit 4d5fa31c authored by Bei Wang's avatar Bei Wang
Browse files

turn off use_memory_efficient_kernel off only for fp16 in primitives.py

parent aef97f4b
......@@ -479,7 +479,9 @@ class Attention(nn.Module):
q, k, v = self._prep_qkv(q_x, kv_x)
# [*, Q, H, C_hidden]
use_memory_efficient_kernel = False
float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
if float16_enabled:
use_memory_efficient_kernel = False
if(use_memory_efficient_kernel):
if(len(biases) > 2):
raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment