Commit 83b7542d authored by sxtyzhangzk's avatar sxtyzhangzk Committed by Zhekai Zhang
Browse files

Fix resolution issue in flashattn2

parent bf3669dd
......@@ -118,6 +118,33 @@ Attention::Attention(int num_heads, int dim_head, Device device) :
headmask_type = headmask_type.copy(device);
}
Tensor Attention::forward(Tensor qkv) {
assert(qkv.ndims() == 3);
const Device device = qkv.device();
const int batch_size = qkv.shape[0];
const int num_tokens = qkv.shape[1];
assert(qkv.shape[2] == num_heads * dim_head * 3);
Tensor reshaped = qkv.view({batch_size, num_tokens, num_heads * 3, dim_head});
Tensor q = reshaped.slice(2, 0, num_heads);
Tensor k = reshaped.slice(2, num_heads, num_heads * 2);
Tensor v = reshaped.slice(2, num_heads * 2, num_heads * 3);
Tensor raw_attn_output = mha_fwd(q, k, v,
0.0f,
pow(q.shape[-1], (-0.5)),
false, -1, -1, false
).front();
assert(raw_attn_output.shape[0] == batch_size);
assert(raw_attn_output.shape[1] == num_tokens);
assert(raw_attn_output.shape[2] == num_heads);
assert(raw_attn_output.shape[3] == dim_head);
return raw_attn_output.view({batch_size * num_tokens, num_heads, dim_head});
}
Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
const bool cast_fp16 = this->force_fp16 && qkv.scalar_type() != Tensor::FP16;
......@@ -312,7 +339,8 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
debug("qkv", qkv);
// Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
attn_output = attn.forward(qkv, {}, 0);
// attn_output = attn.forward(qkv, {}, 0);
attn_output = attn.forward(qkv);
attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
} else if (attnImpl == AttentionImpl::NunchakuFP16) {
assert(batch_size == 1);
......@@ -501,7 +529,11 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
nvtxRangePushA("Attention");
raw_attn_output = attn.forward(concat, pool, sparsityRatio);
if (pool.valid()) {
raw_attn_output = attn.forward(concat, pool, sparsityRatio);
} else {
raw_attn_output = attn.forward(concat);
}
nvtxRangePop();
......
......@@ -63,6 +63,7 @@ public:
static constexpr int POOL_SIZE = 128;
Attention(int num_heads, int dim_head, Device device);
Tensor forward(Tensor qkv);
Tensor forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio);
static void setForceFP16(Module *module, bool value);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment