fengzch-das / nunchaku

Commit 781a9ec8
authored Jan 07, 2026 by fengzch

fix: add call_fa_mha_fwd

parent 45ccfe64

Showing 2 changed files with 97 additions and 40 deletions:

    src/FluxModel.cpp    +80  -23
    src/SanaModel.cpp    +17  -17
src/FluxModel.cpp  (view file @ 781a9ec8)
@@ -2,19 +2,76 @@
#include "kernels/misc_kernels.h"
#include "kernels/misc_kernels.h"
#include "kernels/gemm_batched.h"
#include "kernels/gemm_batched.h"
#include "kernels/zgemm/zgemm.h"
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
#include "activation.h"
#include "activation.h"
#include "Tensor.h"
// #include <nvtx3/nvToolsExt.h>
// #include <nvtx3/nvToolsExt.h>
#include <roctx.h>
#include <roctx.h>
#include <pybind11/functional.h>
#include <pybind11/functional.h>
#include <flash_c_api.h>
#include <iostream>
#include <iostream>
using
spdlog
::
fmt_lib
::
format
;
using
spdlog
::
fmt_lib
::
format
;
using
namespace
nunchaku
;
using
namespace
nunchaku
;
+Tensor call_fa_mha_fwd(Tensor &q, // batch_size x seqlen_q x num_heads x head_size
+                       Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
+                       Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
+                       // c10::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x head_size
+                       // c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+                       const float p_dropout,
+                       const float softmax_scale,
+                       bool is_causal,
+                       int window_size_left,
+                       int window_size_right,
+                       const bool return_softmax
+                       // c10::optional<at::Generator> gen_
+) {
+    // printf("LOG(INFO) %s: %d %s\n", __FILE__, __LINE__, __func__);
+    Tensor o = Tensor::empty_like(q);
+    size_t workspace_size = mha_fwd_workspace(
+        q.shape[0], q.shape[1], k.shape[1], q.shape[2], k.shape[2], q.shape[3], k.shape[3], false);
+    const Device device = q.device();
+    Tensor workspace = Tensor::allocate({1, 1, 1, (int)workspace_size}, Tensor::INT8, device);
+    mha_fwd(q.data_ptr(),
+            k.data_ptr(),
+            v.data_ptr(),
+            o.data_ptr(),
+            nullptr,                                            //* alibi
+            nullptr,                                            //* rng_state
+            workspace.data_ptr(),                               //* workspace
+            q.shape[0], q.shape[1], k.shape[1],                 //* sizes
+            q.shape[2], k.shape[2], q.shape[3], k.shape[3],
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3), //* q strides
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3), //* k strides
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3), //* v strides
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3), //* o strides
+            1, 1,                                               //* alibi strides
+            p_dropout,                                          //* p_dropout
+            softmax_scale,                                      //* softmax_scale
+            is_causal,                                          //* is_causal
+            window_size_left, window_size_right,                //* window sizes
+            0.0f,                                               //* softcap
+            return_softmax,                                     //* return_softmax
+            0,                                                  //* seed
+            q.scalar_type() == Tensor::ScalarType::BF16,        //* is_bf16
+            false                                               //* is_bhsd
+    );
+    return o;
+}
 Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
+    std::cout << "Called forward_mlp " << std::endl;
     Tensor ff_output = fc2.forward_quant(std::get<GEMM_W4A4::QuantizedActivation>(
         fc1.forward(norm_hidden_states, GEMM_W4A4::FuseOptions::GELU_QUANT, &fc2)));
     return ff_output;
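
The new call_fa_mha_fwd wrapper allocates its own output and workspace and returns the output tensor directly, so call sites no longer take .front() as they did with the previous mha_fwd overload (see the next hunk). A minimal usage sketch, assuming BSHD-layout inputs (batch x seqlen x heads x head_size) managed through the project's Tensor API; the run_attention name and the inline scale computation are illustrative, not part of the commit:

// Hypothetical caller, mirroring the non-causal full-attention call site in Attention::forward.
Tensor run_attention(Tensor &q, Tensor &k, Tensor &v) {
    float scale = pow(q.shape[-1], (-0.5)); // 1/sqrt(head_size), e.g. head_size 128 -> ~0.0884
    return call_fa_mha_fwd(q, k, v,
                           0.0f,   // p_dropout: no dropout at inference
                           scale,  // softmax_scale
                           false,  // is_causal
                           -1, -1, // window_size_left / window_size_right: no sliding window
                           false); // return_softmax
}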
@@ -118,7 +175,7 @@ Tensor Attention::forward(Tensor qkv) {
     Tensor k = reshaped.slice(2, num_heads, num_heads * 2);
     Tensor v = reshaped.slice(2, num_heads * 2, num_heads * 3);

-    Tensor raw_attn_output = mha_fwd(q, k, v, 0.0f, pow(q.shape[-1], (-0.5)), false, -1, -1, false).front();
+    Tensor raw_attn_output = call_fa_mha_fwd(q, k, v, 0.0f, pow(q.shape[-1], (-0.5)), false, -1, -1, false);

     assert(raw_attn_output.shape[0] == batch_size);
     assert(raw_attn_output.shape[1] == num_tokens);
@@ -201,27 +258,27 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
     spdlog::debug("q,k,v={}", q.shape.str());

-    Tensor raw_attn_output = mha_fwd_block(q,
-                                           k,
-                                           v,
-                                           cu_seqlens,
-                                           cu_seqlens,
-                                           POOL_SIZE,
-                                           POOL_SIZE,
-                                           headmask_type,
-                                           {},
-                                           blockmask,
-                                           num_tokens,
-                                           num_tokens,
-                                           0.0f,
-                                           pow(q.shape[-1], (-0.5)),
-                                           false,
-                                           false,
-                                           false,
-                                           -1,
-                                           -1)
-                                 .front();
+    // Tensor raw_attn_output = mha_fwd_block(q,
+    //                                        k,
+    //                                        v,
+    //                                        cu_seqlens,
+    //                                        cu_seqlens,
+    //                                        POOL_SIZE,
+    //                                        POOL_SIZE,
+    //                                        headmask_type,
+    //                                        {},
+    //                                        blockmask,
+    //                                        num_tokens,
+    //                                        num_tokens,
+    //                                        0.0f,
+    //                                        pow(q.shape[-1], (-0.5)),
+    //                                        false,
+    //                                        false,
+    //                                        false,
+    //                                        -1,
+    //                                        -1)
+    //                              .front();
+    std::cout << "mha_fwd_block not support !!!" << std::endl;

     debug("raw_attn_output", raw_attn_output);

     if (cast_fp16) {
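
With mha_fwd_block commented out, this branch now only reports that block attention is unsupported before continuing. A purely hypothetical hardening sketch, not part of the commit, that fails loudly if the sparse path is ever taken (assumes <stdexcept> is available; spdlog is already used in this file):

// Hypothetical guard (not in the commit): abort instead of continuing past the disabled
// block-sparse call.
spdlog::error("mha_fwd_block is not supported in this build; sparse attention is unavailable");
throw std::runtime_error("mha_fwd_block not supported");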
src/SanaModel.cpp  (view file @ 781a9ec8)
@@ -164,23 +164,23 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
     Tensor k = kv.slice(1, 0, num_heads);
     Tensor v = kv.slice(1, num_heads, num_heads * 2);

-    Tensor attn_output = mha_varlen_fwd(q,
-                                        k,
-                                        v,
-                                        cu_seqlens_img,
-                                        cu_seqlens_txt,
-                                        num_tokens_img,
-                                        num_tokens_txt,
-                                        0.0f,
-                                        pow(q.shape[-1], (-0.5)),
-                                        false,
-                                        false,
-                                        -1,
-                                        -1,
-                                        false)
-                             .front()
-                             .view({batch_size, num_tokens_img, num_heads * head_dim});
+    // Tensor attn_output = mha_varlen_fwd(q,
+    //                                     k,
+    //                                     v,
+    //                                     cu_seqlens_img,
+    //                                     cu_seqlens_txt,
+    //                                     num_tokens_img,
+    //                                     num_tokens_txt,
+    //                                     0.0f,
+    //                                     pow(q.shape[-1], (-0.5)),
+    //                                     false,
+    //                                     false,
+    //                                     -1,
+    //                                     -1,
+    //                                     false)
+    //                          .front()
+    //                          .view({batch_size, num_tokens_img, num_heads * head_dim});
+    std::cout << "mha_varlen_fwd not support !!!" << std::endl;

     // Tensor attn_output = mha_fwd(q, k, v,
     //                              0.0f,
     //                              pow(q.shape[-1], (-0.5)),
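
The trailing context above shows a dense mha_fwd fallback that is likewise commented out. A hypothetical sketch, not part of the commit, of routing this cross-attention through the new call_fa_mha_fwd wrapper instead; q_dense, k_dense and v_dense are placeholder names for dense batch x seqlen x heads x head_size tensors, and the packing step from the varlen layout to that shape is omitted:

// Hypothetical dense fallback (not in the commit): only meaningful when every sequence in
// the batch has the same image/text token counts, so the varlen layout can be treated as dense.
Tensor attn_output = call_fa_mha_fwd(q_dense, k_dense, v_dense,
                                     0.0f,                           // p_dropout
                                     pow(q_dense.shape[-1], (-0.5)), // softmax_scale
                                     false,                          // is_causal
                                     -1, -1,                         // no sliding window
                                     false)                          // return_softmax
                         .view({batch_size, num_tokens_img, num_heads * head_dim});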