Commit d8e379b5 authored by fengzch's avatar fengzch
Browse files

fix: mha_varlen_fwd not support

parent 4e690109
...@@ -180,7 +180,7 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens ...@@ -180,7 +180,7 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
// false) // false)
// .front() // .front()
// .view({batch_size, num_tokens_img, num_heads * head_dim}); // .view({batch_size, num_tokens_img, num_heads * head_dim});
Tensor attn_output = Tensor::ones({batch_size * num_tokens, num_heads, dim_head}, Tensor::FP16, Device::cuda()); Tensor attn_output = Tensor::ones({batch_size, num_tokens_img, num_heads * head_dim}, Tensor::FP16, Device::cuda());
std::cout << "mha_varlen_fwd not support !!!" << std::endl; std::cout << "mha_varlen_fwd not support !!!" << std::endl;
// Tensor attn_output = mha_fwd(q, k, v, // Tensor attn_output = mha_fwd(q, k, v,
// 0.0f, // 0.0f,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment