Commit 781a9ec8 authored by fengzch

fix: add call_fa_mha_fwd

parent 45ccfe64
@@ -2,19 +2,76 @@
#include "kernels/misc_kernels.h" #include "kernels/misc_kernels.h"
#include "kernels/gemm_batched.h" #include "kernels/gemm_batched.h"
#include "kernels/zgemm/zgemm.h" #include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
#include "activation.h" #include "activation.h"
#include "Tensor.h"
// #include <nvtx3/nvToolsExt.h> // #include <nvtx3/nvToolsExt.h>
#include <roctx.h> #include <roctx.h>
#include <pybind11/functional.h> #include <pybind11/functional.h>
#include <flash_c_api.h>
#include <iostream> #include <iostream>
using spdlog::fmt_lib::format; using spdlog::fmt_lib::format;
using namespace nunchaku; using namespace nunchaku;
+Tensor call_fa_mha_fwd(Tensor &q,  // batch_size x seqlen_q x num_heads x head_size
+                       Tensor &k,  // batch_size x seqlen_k x num_heads_k x head_size
+                       Tensor &v,  // batch_size x seqlen_k x num_heads_k x head_size
+                       // c10::optional<at::Tensor> &out_,           // batch_size x seqlen_q x num_heads x head_size
+                       // c10::optional<at::Tensor> &alibi_slopes_,  // num_heads or batch_size x num_heads
+                       const float p_dropout,
+                       const float softmax_scale,
+                       bool is_causal,
+                       int window_size_left,
+                       int window_size_right,
+                       const bool return_softmax
+                       // c10::optional<at::Generator> gen_
+) {
+    // printf("LOG(INFO) %s: %d %s\n", __FILE__, __LINE__, __func__);
+    Tensor o = Tensor::empty_like(q);
+    size_t workspace_size = mha_fwd_workspace(
+        q.shape[0], q.shape[1], k.shape[1],
+        q.shape[2], k.shape[2],
+        q.shape[3], k.shape[3],
+        false
+    );
+    const Device device = q.device();
+    Tensor workspace = Tensor::allocate({1, 1, 1, (int)workspace_size}, Tensor::INT8, device);
+    mha_fwd(
+        q.data_ptr(), k.data_ptr(), v.data_ptr(), o.data_ptr(),
+        nullptr,                                             //* alibi
+        nullptr,                                             //* rng_state
+        workspace.data_ptr(),                                //* workspace
+        q.shape[0], q.shape[1], k.shape[1],                  //* sizes
+        q.shape[2], k.shape[2],
+        q.shape[3], k.shape[3],
+        q.stride(0), q.stride(1), q.stride(2), q.stride(3),  //* q strides
+        k.stride(0), k.stride(1), k.stride(2), k.stride(3),  //* k strides
+        v.stride(0), v.stride(1), v.stride(2), v.stride(3),  //* v strides
+        o.stride(0), o.stride(1), o.stride(2), o.stride(3),  //* o strides
+        1, 1,                                                //* alibi strides
+        p_dropout,                                           //* p_dropout
+        softmax_scale,                                       //* softmax_scale
+        is_causal,                                           //* is_causal
+        window_size_left,
+        window_size_right,                                   //* window sizes
+        0.0f,                                                //* softcap
+        return_softmax,                                      //* return_softmax
+        0,                                                   //* seed
+        q.scalar_type() == Tensor::ScalarType::BF16,         //* is_bf16
+        false                                                //* is_bhsd
+    );
+    return o;
+}
Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
+    std::cout << "Called forward_mlp " << std::endl;
    Tensor ff_output = fc2.forward_quant(std::get<GEMM_W4A4::QuantizedActivation>(
        fc1.forward(norm_hidden_states, GEMM_W4A4::FuseOptions::GELU_QUANT, &fc2)));
    return ff_output;
@@ -118,7 +175,7 @@ Tensor Attention::forward(Tensor qkv) {
    Tensor k = reshaped.slice(2, num_heads, num_heads * 2);
    Tensor v = reshaped.slice(2, num_heads * 2, num_heads * 3);
-   Tensor raw_attn_output = mha_fwd(q, k, v, 0.0f, pow(q.shape[-1], (-0.5)), false, -1, -1, false).front();
+   Tensor raw_attn_output = call_fa_mha_fwd(q, k, v, 0.0f, pow(q.shape[-1], (-0.5)), false, -1, -1, false);
    assert(raw_attn_output.shape[0] == batch_size);
    assert(raw_attn_output.shape[1] == num_tokens);
@@ -201,27 +258,27 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
    spdlog::debug("q,k,v={}", q.shape.str());
-   Tensor raw_attn_output = mha_fwd_block(q,
-                                          k,
-                                          v,
-                                          cu_seqlens,
-                                          cu_seqlens,
-                                          POOL_SIZE,
-                                          POOL_SIZE,
-                                          headmask_type,
-                                          {},
-                                          blockmask,
-                                          num_tokens,
-                                          num_tokens,
-                                          0.0f,
-                                          pow(q.shape[-1], (-0.5)),
-                                          false,
-                                          false,
-                                          false,
-                                          -1,
-                                          -1)
-                                .front();
+   // Tensor raw_attn_output = mha_fwd_block(q,
+   //                                        k,
+   //                                        v,
+   //                                        cu_seqlens,
+   //                                        cu_seqlens,
+   //                                        POOL_SIZE,
+   //                                        POOL_SIZE,
+   //                                        headmask_type,
+   //                                        {},
+   //                                        blockmask,
+   //                                        num_tokens,
+   //                                        num_tokens,
+   //                                        0.0f,
+   //                                        pow(q.shape[-1], (-0.5)),
+   //                                        false,
+   //                                        false,
+   //                                        false,
+   //                                        -1,
+   //                                        -1)
+   //                              .front();
+   std::cout << "mha_fwd_block not support !!!" << std::endl;
debug("raw_attn_output", raw_attn_output); debug("raw_attn_output", raw_attn_output);
if (cast_fp16) { if (cast_fp16) {
...
@@ -164,23 +164,23 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
    Tensor k = kv.slice(1, 0, num_heads);
    Tensor v = kv.slice(1, num_heads, num_heads * 2);
-   Tensor attn_output = mha_varlen_fwd(q,
-                                       k,
-                                       v,
-                                       cu_seqlens_img,
-                                       cu_seqlens_txt,
-                                       num_tokens_img,
-                                       num_tokens_txt,
-                                       0.0f,
-                                       pow(q.shape[-1], (-0.5)),
-                                       false,
-                                       false,
-                                       -1,
-                                       -1,
-                                       false)
-                            .front()
-                            .view({batch_size, num_tokens_img, num_heads * head_dim});
+   // Tensor attn_output = mha_varlen_fwd(q,
+   //                                     k,
+   //                                     v,
+   //                                     cu_seqlens_img,
+   //                                     cu_seqlens_txt,
+   //                                     num_tokens_img,
+   //                                     num_tokens_txt,
+   //                                     0.0f,
+   //                                     pow(q.shape[-1], (-0.5)),
+   //                                     false,
+   //                                     false,
+   //                                     -1,
+   //                                     -1,
+   //                                     false)
+   //                          .front()
+   //                          .view({batch_size, num_tokens_img, num_heads * head_dim});
+   std::cout << "mha_varlen_fwd not support !!!" << std::endl;
    // Tensor attn_output = mha_fwd(q, k, v,
    //                              0.0f,
    //                              pow(q.shape[-1], (-0.5)),
...
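For quick reference, a condensed usage sketch of the new wrapper, drawn from the Attention::forward hunk above (the argument-name comments are editorial annotations, not part of the committed code):

    // q, k, v: batch_size x seqlen x num_heads x head_size slices of the fused qkv tensor
    Tensor raw_attn_output = call_fa_mha_fwd(q, k, v,
                                             /*p_dropout=*/0.0f,
                                             /*softmax_scale=*/pow(q.shape[-1], (-0.5)),  // 1/sqrt(head_size)
                                             /*is_causal=*/false,
                                             /*window_size_left=*/-1,   // -1 disables sliding-window attention
                                             /*window_size_right=*/-1,
                                             /*return_softmax=*/false);
    // Unlike the old mha_fwd call, the wrapper returns the output tensor directly, so no .front() is needed.
    assert(raw_attn_output.shape[0] == batch_size);  // output has the same layout as q
    assert(raw_attn_output.shape[1] == num_tokens);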