from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
import open_clip
from ldm.util import default, count_params
class AbstractEncoder(nn.Module):
...
...
@@ -20,189 +16,149 @@ class AbstractEncoder(nn.Module):
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)