Unverified Commit b4b6e903 authored by Aman Gupta Karmani, committed by GitHub

add benchmark for xformers fa2 wrapper (#492)

parent 45ba93cd
@@ -79,7 +79,8 @@ dropout_p = 0.0
 methods = (["Flash2", "Pytorch"]
            + (["Triton"] if attention_triton is not None else [])
-           + (["xformers"] if xops is not None else []))
+           + (["xformers.c"] if xops is not None else [])
+           + (["xformers.f"] if xops is not None else []))
 time_f = {}
 time_b = {}
@@ -139,8 +140,19 @@ for causal in causal_vals:
                     attn_bias=xops.LowerTriangularMask() if causal else None,
                     op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
                 )
-                time_f[config, "xformers"] = f
-                time_b[config, "xformers"] = b
+                time_f[config, "xformers.c"] = f
+                time_b[config, "xformers.c"] = b
 
+            if xops is not None:
+                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
+                                       requires_grad=True) for _ in range(3)]
+                f, b = time_fwd_bwd(
+                    xops.memory_efficient_attention, q, k, v,
+                    attn_bias=xops.LowerTriangularMask() if causal else None,
+                    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
+                )
+                time_f[config, "xformers.f"] = f
+                time_b[config, "xformers.f"] = b
+
             print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
             for method in methods:
...
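
The diff splits the old "xformers" benchmark entry into two backends: "xformers.c" keeps the existing CUTLASS ops, while the new "xformers.f" block times xformers' wrapper around FlashAttention-2 by pinning the op pair passed to memory_efficient_attention. A minimal standalone sketch of the call being timed, outside the benchmark's loops and timing harness, follows; the batch/sequence/head sizes are illustrative placeholders, and it assumes an xformers install with the FlashAttention backend and a CUDA device.

# Standalone sketch (assumptions: xformers built with the FlashAttention
# backend, a CUDA device available; tensor sizes are placeholders).
import torch
import xformers.ops as xops

batch_size, seqlen, nheads, headdim = 2, 1024, 16, 64
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim,
                       device="cuda", dtype=torch.float16, requires_grad=True)
           for _ in range(3)]

# op=... pins dispatch to the FlashAttention forward/backward ops
# ("xformers.f" in the benchmark); passing
# (xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp) instead selects the
# CUTLASS path ("xformers.c").
out = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=xops.LowerTriangularMask(),  # the causal=True case
    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
)
out.sum().backward()  # the benchmark times this backward pass as well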